diff --git a/backend/dependencies.py b/backend/dependencies.py index 6891d67..45956b9 100644 --- a/backend/dependencies.py +++ b/backend/dependencies.py @@ -2,6 +2,7 @@ Shared dependencies and service instances. All route modules import from here to avoid circular imports. """ +from typing import Optional from fastapi import HTTPException, Depends from services.indexer_optimized import OptimizedCodeIndexer @@ -43,21 +44,25 @@ redis_client = cache.redis if cache.redis else None -def get_repo_or_404(repo_id: str, user_id: str) -> dict: +def get_repo_or_404(repo_id: str, user_id: Optional[str]) -> dict: """ Get repository with ownership verification. - Returns 404 if not found or user doesn't own it. + Raises 401 if user_id is None, 404 if not found or user doesn't own it. """ + if user_id is None: + raise HTTPException(status_code=401, detail="User ID required for this operation") repo = repo_manager.get_repo_for_user(repo_id, user_id) if not repo: raise HTTPException(status_code=404, detail="Repository not found") return repo -def verify_repo_access(repo_id: str, user_id: str) -> None: +def verify_repo_access(repo_id: str, user_id: Optional[str]) -> None: """ Verify user has access to repository. - Raises 404 if no access. + Raises 401 if user_id is None, 404 if no access. """ + if user_id is None: + raise HTTPException(status_code=401, detail="User ID required for this operation") if not repo_manager.verify_ownership(repo_id, user_id): raise HTTPException(status_code=404, detail="Repository not found") diff --git a/backend/middleware/auth.py b/backend/middleware/auth.py index 511fa40..68c8964 100644 --- a/backend/middleware/auth.py +++ b/backend/middleware/auth.py @@ -24,6 +24,13 @@ async def demo_search(auth: AuthContext = Depends(public_auth)): import os import hashlib +from services.exceptions import ( + AuthenticationError, + TokenExpiredError, + InvalidTokenError, + TokenMissingClaimError, +) + from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer from fastapi.security.http import HTTPAuthorizationCredentials @@ -59,7 +66,11 @@ def identifier(self) -> str: # Core validation functions def _validate_jwt(token: str) -> Optional[AuthContext]: - """Validate Supabase JWT token""" + """Validate Supabase JWT token. + + Returns AuthContext on success, None if token isn't a JWT (allows API key fallback). + Raises HTTPException for specific JWT failures (expired, bad audience, missing claim). + """ try: from services.auth import get_auth_service auth_service = get_auth_service() @@ -70,6 +81,15 @@ def _validate_jwt(token: str) -> Optional[AuthContext]: email=user.get("email"), tier=user.get("metadata", {}).get("tier", "free") ) + except TokenExpiredError: + raise HTTPException(status_code=401, detail="Token expired") + except TokenMissingClaimError as e: + raise HTTPException(status_code=401, detail=str(e)) + except InvalidTokenError: + # Could be a non-JWT token (API key) -- allow fallback + return None + except AuthenticationError as e: + raise HTTPException(status_code=401, detail=str(e)) except Exception: return None @@ -181,7 +201,12 @@ async def get_current_user( """ from services.auth import get_auth_service auth_service = get_auth_service() - return auth_service.verify_jwt(credentials.credentials) + try: + return auth_service.verify_jwt(credentials.credentials) + except (TokenExpiredError, TokenMissingClaimError) as e: + raise HTTPException(status_code=401, detail=str(e)) + except AuthenticationError: + raise HTTPException(status_code=401, detail="Invalid or expired token") async def get_optional_user( diff --git a/backend/services/auth.py b/backend/services/auth.py index a6bd688..b0af2de 100644 --- a/backend/services/auth.py +++ b/backend/services/auth.py @@ -5,11 +5,17 @@ from fastapi import HTTPException, status from typing import Optional, Dict, Any import os -import jwt +import jwt as pyjwt from datetime import datetime from supabase import create_client, Client from services.observability import logger +from services.exceptions import ( + AuthenticationError, + TokenExpiredError, + InvalidTokenError, + TokenMissingClaimError, +) class SupabaseAuthService: @@ -48,7 +54,7 @@ def verify_jwt(self, token: str) -> Dict[str, Any]: def _verify_local(self, token: str) -> Dict[str, Any]: """Decode and verify JWT locally with HS256 secret.""" try: - payload = jwt.decode( + payload = pyjwt.decode( token, self.jwt_secret, algorithms=["HS256"], @@ -58,10 +64,7 @@ def _verify_local(self, token: str) -> Dict[str, Any]: user_id = payload.get("sub") if not user_id: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Token missing subject claim", - ) + raise TokenMissingClaimError("sub") return { "user_id": user_id, @@ -69,22 +72,13 @@ def _verify_local(self, token: str) -> Dict[str, Any]: "metadata": payload.get("user_metadata") or {}, } - except jwt.ExpiredSignatureError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Token expired", - ) - except jwt.InvalidAudienceError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token audience", - ) - except jwt.InvalidTokenError as e: + except pyjwt.ExpiredSignatureError as e: + raise TokenExpiredError("Token expired") from e + except pyjwt.InvalidAudienceError as e: + raise InvalidTokenError("Invalid token audience") from e + except pyjwt.InvalidTokenError as e: logger.debug("JWT decode failed", error=str(e)) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token", - ) + raise InvalidTokenError("Invalid token") from e def _verify_via_api(self, token: str) -> Dict[str, Any]: """Fallback: verify via Supabase API call (requires network).""" @@ -92,24 +86,18 @@ def _verify_via_api(self, token: str) -> Dict[str, Any]: response = self.client.auth.get_user(token) if not response.user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - ) + raise InvalidTokenError("Invalid or expired token") return { "user_id": response.user.id, "email": response.user.email, "metadata": response.user.user_metadata or {}, } - except HTTPException: + except AuthenticationError: raise except Exception as e: logger.debug("API-based JWT verification failed", error=str(e)) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Token verification failed", - ) + raise AuthenticationError("Token verification failed") from e async def signup(self, email: str, password: str, github_username: Optional[str] = None) -> Dict[str, Any]: """ diff --git a/backend/services/exceptions.py b/backend/services/exceptions.py new file mode 100644 index 0000000..cd20b6d --- /dev/null +++ b/backend/services/exceptions.py @@ -0,0 +1,49 @@ +""" +Domain exceptions for authentication. + +Services raise these; the middleware/route layer translates them to HTTP responses. +This decouples business logic from the HTTP framework. +""" + + +class AuthenticationError(Exception): + """Base auth error. All auth exceptions inherit from this.""" + pass + + +class TokenExpiredError(AuthenticationError): + """JWT token has expired.""" + pass + + +class InvalidTokenError(AuthenticationError): + """JWT token is malformed, wrong signature, or invalid audience.""" + pass + + +class TokenMissingClaimError(AuthenticationError): + """JWT token is missing a required claim (e.g., sub).""" + + def __init__(self, claim: str = "sub"): + super().__init__(f"Token missing required claim: {claim}") + self.claim = claim + + +class InvalidCredentialsError(AuthenticationError): + """Login failed due to wrong email/password.""" + pass + + +class SignupError(AuthenticationError): + """User registration failed.""" + pass + + +class SessionError(AuthenticationError): + """Token refresh or logout failed.""" + pass + + +class UserIdRequiredError(AuthenticationError): + """Operation requires a user_id but auth context has None (e.g., API key without user).""" + pass diff --git a/backend/tests/test_auth_hardening.py b/backend/tests/test_auth_hardening.py new file mode 100644 index 0000000..f4b14c3 --- /dev/null +++ b/backend/tests/test_auth_hardening.py @@ -0,0 +1,77 @@ +"""Tests for auth hardening -- domain exceptions + null safety (OPE-76, OPE-77).""" +import pytest +from unittest.mock import patch, MagicMock +from fastapi import HTTPException + + +class TestNullUserIdSafety: + """API key users (no user_id) get 401, not confusing 404.""" + + def test_search_with_null_user_id_returns_401(self, client, valid_headers): + """Search should reject None user_id via the real verify_repo_access null guard.""" + from fastapi.testclient import TestClient + from main import app + from middleware.auth import require_auth, AuthContext + + # Mock auth to return a context with no user_id (API key user) + async def mock_auth_no_user(): + return AuthContext(api_key_name="test-key", user_id=None) + + app.dependency_overrides[require_auth] = mock_auth_no_user + try: + resp = client.post( + "/api/v1/search", + json={"query": "auth", "repo_id": "test"}, + headers=valid_headers, + ) + # verify_repo_access in dependencies.py should catch None user_id + assert resp.status_code == 401 + assert "User ID required" in resp.json()["detail"] + finally: + app.dependency_overrides.pop(require_auth, None) + + def test_get_repo_or_404_rejects_none_user_id(self): + """get_repo_or_404 should raise 401 when user_id is None.""" + from dependencies import get_repo_or_404 + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc: + get_repo_or_404("some-repo", None) + assert exc.value.status_code == 401 + + def test_verify_repo_access_rejects_none_user_id(self): + """verify_repo_access should raise 401 when user_id is None.""" + from dependencies import verify_repo_access + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc: + verify_repo_access("some-repo", None) + assert exc.value.status_code == 401 + + +class TestDomainExceptions: + """Auth service raises domain exceptions, not HTTPException.""" + + def test_expired_token_raises_domain_exception(self): + """Auth service should raise TokenExpiredError, not HTTPException.""" + from services.exceptions import TokenExpiredError + assert issubclass(TokenExpiredError, Exception) + assert not issubclass(TokenExpiredError, HTTPException) + + def test_exception_hierarchy(self): + """All auth exceptions inherit from AuthenticationError.""" + from services.exceptions import ( + AuthenticationError, + TokenExpiredError, + InvalidTokenError, + TokenMissingClaimError, + InvalidCredentialsError, + SignupError, + SessionError, + UserIdRequiredError, + ) + for exc_class in [ + TokenExpiredError, InvalidTokenError, TokenMissingClaimError, + InvalidCredentialsError, SignupError, SessionError, UserIdRequiredError, + ]: + assert issubclass(exc_class, AuthenticationError) diff --git a/backend/tests/test_jwt_local_decode.py b/backend/tests/test_jwt_local_decode.py index 07e08f0..c525e5f 100644 --- a/backend/tests/test_jwt_local_decode.py +++ b/backend/tests/test_jwt_local_decode.py @@ -50,40 +50,34 @@ def test_bearer_prefix_stripped(self, auth_service): assert result["user_id"] == "user-456" - def test_expired_token_raises_401(self, auth_service): + def test_expired_token_raises_error(self, auth_service): token = _make_token({"sub": "user-789", "exp": int(time.time()) - 60}) - from fastapi import HTTPException - with pytest.raises(HTTPException) as exc: + from services.exceptions import TokenExpiredError + with pytest.raises(TokenExpiredError): auth_service.verify_jwt(token) - assert exc.value.status_code == 401 - assert "expired" in exc.value.detail.lower() - def test_wrong_secret_raises_401(self, auth_service): + def test_wrong_secret_raises_error(self, auth_service): token = _make_token({"sub": "user-000"}, secret="wrong-secret") - from fastapi import HTTPException - with pytest.raises(HTTPException) as exc: + from services.exceptions import InvalidTokenError + with pytest.raises(InvalidTokenError): auth_service.verify_jwt(token) - assert exc.value.status_code == 401 - def test_missing_sub_claim_raises_401(self, auth_service): + def test_missing_sub_claim_raises_error(self, auth_service): token = _make_token({"email": "no-sub@test.com"}) - from fastapi import HTTPException - with pytest.raises(HTTPException) as exc: + from services.exceptions import TokenMissingClaimError + with pytest.raises(TokenMissingClaimError): auth_service.verify_jwt(token) - assert exc.value.status_code == 401 - assert "subject" in exc.value.detail.lower() - def test_wrong_audience_raises_401(self, auth_service): + def test_wrong_audience_raises_error(self, auth_service): payload = {"sub": "user-aud", "aud": "wrong-audience", "exp": int(time.time()) + 3600} token = pyjwt.encode(payload, JWT_SECRET, algorithm="HS256") - from fastapi import HTTPException - with pytest.raises(HTTPException) as exc: + from services.exceptions import InvalidTokenError + with pytest.raises(InvalidTokenError): auth_service.verify_jwt(token) - assert exc.value.status_code == 401 def test_no_network_call_made(self, auth_service): """The whole point of OPE-75: verify_jwt should NOT hit the network."""