Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions backend/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Shared dependencies and service instances.
All route modules import from here to avoid circular imports.
"""
from typing import Optional
from fastapi import HTTPException, Depends

from services.indexer_optimized import OptimizedCodeIndexer
Expand Down Expand Up @@ -43,21 +44,25 @@
redis_client = cache.redis if cache.redis else None


def get_repo_or_404(repo_id: str, user_id: str) -> dict:
def get_repo_or_404(repo_id: str, user_id: Optional[str]) -> dict:
"""
Get repository with ownership verification.
Returns 404 if not found or user doesn't own it.
Raises 401 if user_id is None, 404 if not found or user doesn't own it.
"""
if user_id is None:
raise HTTPException(status_code=401, detail="User ID required for this operation")
repo = repo_manager.get_repo_for_user(repo_id, user_id)
if not repo:
raise HTTPException(status_code=404, detail="Repository not found")
return repo


def verify_repo_access(repo_id: str, user_id: str) -> None:
def verify_repo_access(repo_id: str, user_id: Optional[str]) -> None:
"""
Verify user has access to repository.
Raises 404 if no access.
Raises 401 if user_id is None, 404 if no access.
"""
if user_id is None:
raise HTTPException(status_code=401, detail="User ID required for this operation")
if not repo_manager.verify_ownership(repo_id, user_id):
raise HTTPException(status_code=404, detail="Repository not found")
29 changes: 27 additions & 2 deletions backend/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ async def demo_search(auth: AuthContext = Depends(public_auth)):
import os
import hashlib

from services.exceptions import (
AuthenticationError,
TokenExpiredError,
InvalidTokenError,
TokenMissingClaimError,
)

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer
from fastapi.security.http import HTTPAuthorizationCredentials
Expand Down Expand Up @@ -59,7 +66,11 @@ def identifier(self) -> str:
# Core validation functions

def _validate_jwt(token: str) -> Optional[AuthContext]:
"""Validate Supabase JWT token"""
"""Validate Supabase JWT token.

Returns AuthContext on success, None if token isn't a JWT (allows API key fallback).
Raises HTTPException for specific JWT failures (expired, bad audience, missing claim).
"""
try:
from services.auth import get_auth_service
auth_service = get_auth_service()
Expand All @@ -70,6 +81,15 @@ def _validate_jwt(token: str) -> Optional[AuthContext]:
email=user.get("email"),
tier=user.get("metadata", {}).get("tier", "free")
)
except TokenExpiredError:
raise HTTPException(status_code=401, detail="Token expired")
except TokenMissingClaimError as e:
raise HTTPException(status_code=401, detail=str(e))
except InvalidTokenError:
# Could be a non-JWT token (API key) -- allow fallback
return None
except AuthenticationError as e:
raise HTTPException(status_code=401, detail=str(e))
except Exception:
return None

Expand Down Expand Up @@ -181,7 +201,12 @@ async def get_current_user(
"""
from services.auth import get_auth_service
auth_service = get_auth_service()
return auth_service.verify_jwt(credentials.credentials)
try:
return auth_service.verify_jwt(credentials.credentials)
except (TokenExpiredError, TokenMissingClaimError) as e:
raise HTTPException(status_code=401, detail=str(e))
except AuthenticationError:
raise HTTPException(status_code=401, detail="Invalid or expired token")


async def get_optional_user(
Expand Down
48 changes: 18 additions & 30 deletions backend/services/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
from fastapi import HTTPException, status
from typing import Optional, Dict, Any
import os
import jwt
import jwt as pyjwt
from datetime import datetime
from supabase import create_client, Client

from services.observability import logger
from services.exceptions import (
AuthenticationError,
TokenExpiredError,
InvalidTokenError,
TokenMissingClaimError,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


class SupabaseAuthService:
Expand Down Expand Up @@ -48,7 +54,7 @@ def verify_jwt(self, token: str) -> Dict[str, Any]:
def _verify_local(self, token: str) -> Dict[str, Any]:
"""Decode and verify JWT locally with HS256 secret."""
try:
payload = jwt.decode(
payload = pyjwt.decode(
token,
self.jwt_secret,
algorithms=["HS256"],
Expand All @@ -58,58 +64,40 @@ def _verify_local(self, token: str) -> Dict[str, Any]:

user_id = payload.get("sub")
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token missing subject claim",
)
raise TokenMissingClaimError("sub")

return {
"user_id": user_id,
"email": payload.get("email"),
"metadata": payload.get("user_metadata") or {},
}

except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired",
)
except jwt.InvalidAudienceError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
)
except jwt.InvalidTokenError as e:
except pyjwt.ExpiredSignatureError as e:
raise TokenExpiredError("Token expired") from e
except pyjwt.InvalidAudienceError as e:
raise InvalidTokenError("Invalid token audience") from e
except pyjwt.InvalidTokenError as e:
logger.debug("JWT decode failed", error=str(e))
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
)
raise InvalidTokenError("Invalid token") from e

def _verify_via_api(self, token: str) -> Dict[str, Any]:
"""Fallback: verify via Supabase API call (requires network)."""
try:
response = self.client.auth.get_user(token)

if not response.user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
)
raise InvalidTokenError("Invalid or expired token")

return {
"user_id": response.user.id,
"email": response.user.email,
"metadata": response.user.user_metadata or {},
}
except HTTPException:
except AuthenticationError:
raise
except Exception as e:
logger.debug("API-based JWT verification failed", error=str(e))
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token verification failed",
)
raise AuthenticationError("Token verification failed") from e

Comment thread
DevanshuNEU marked this conversation as resolved.
async def signup(self, email: str, password: str, github_username: Optional[str] = None) -> Dict[str, Any]:
"""
Expand Down
49 changes: 49 additions & 0 deletions backend/services/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Domain exceptions for authentication.

Services raise these; the middleware/route layer translates them to HTTP responses.
This decouples business logic from the HTTP framework.
"""


class AuthenticationError(Exception):
"""Base auth error. All auth exceptions inherit from this."""
pass


class TokenExpiredError(AuthenticationError):
"""JWT token has expired."""
pass


class InvalidTokenError(AuthenticationError):
"""JWT token is malformed, wrong signature, or invalid audience."""
pass


class TokenMissingClaimError(AuthenticationError):
"""JWT token is missing a required claim (e.g., sub)."""

def __init__(self, claim: str = "sub"):
super().__init__(f"Token missing required claim: {claim}")
self.claim = claim


class InvalidCredentialsError(AuthenticationError):
"""Login failed due to wrong email/password."""
pass


class SignupError(AuthenticationError):
"""User registration failed."""
pass


class SessionError(AuthenticationError):
"""Token refresh or logout failed."""
pass


class UserIdRequiredError(AuthenticationError):
"""Operation requires a user_id but auth context has None (e.g., API key without user)."""
pass
77 changes: 77 additions & 0 deletions backend/tests/test_auth_hardening.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Tests for auth hardening -- domain exceptions + null safety (OPE-76, OPE-77)."""
import pytest
from unittest.mock import patch, MagicMock
from fastapi import HTTPException


class TestNullUserIdSafety:
"""API key users (no user_id) get 401, not confusing 404."""

def test_search_with_null_user_id_returns_401(self, client, valid_headers):
"""Search should reject None user_id via the real verify_repo_access null guard."""
from fastapi.testclient import TestClient
from main import app
from middleware.auth import require_auth, AuthContext

# Mock auth to return a context with no user_id (API key user)
async def mock_auth_no_user():
return AuthContext(api_key_name="test-key", user_id=None)

app.dependency_overrides[require_auth] = mock_auth_no_user
try:
resp = client.post(
"/api/v1/search",
json={"query": "auth", "repo_id": "test"},
headers=valid_headers,
)
# verify_repo_access in dependencies.py should catch None user_id
assert resp.status_code == 401
assert "User ID required" in resp.json()["detail"]
finally:
app.dependency_overrides.pop(require_auth, None)

def test_get_repo_or_404_rejects_none_user_id(self):
"""get_repo_or_404 should raise 401 when user_id is None."""
from dependencies import get_repo_or_404
from fastapi import HTTPException

with pytest.raises(HTTPException) as exc:
get_repo_or_404("some-repo", None)
assert exc.value.status_code == 401

def test_verify_repo_access_rejects_none_user_id(self):
"""verify_repo_access should raise 401 when user_id is None."""
from dependencies import verify_repo_access
from fastapi import HTTPException

with pytest.raises(HTTPException) as exc:
verify_repo_access("some-repo", None)
assert exc.value.status_code == 401


class TestDomainExceptions:
"""Auth service raises domain exceptions, not HTTPException."""

def test_expired_token_raises_domain_exception(self):
"""Auth service should raise TokenExpiredError, not HTTPException."""
from services.exceptions import TokenExpiredError
assert issubclass(TokenExpiredError, Exception)
assert not issubclass(TokenExpiredError, HTTPException)

def test_exception_hierarchy(self):
"""All auth exceptions inherit from AuthenticationError."""
from services.exceptions import (
AuthenticationError,
TokenExpiredError,
InvalidTokenError,
TokenMissingClaimError,
InvalidCredentialsError,
SignupError,
SessionError,
UserIdRequiredError,
)
for exc_class in [
TokenExpiredError, InvalidTokenError, TokenMissingClaimError,
InvalidCredentialsError, SignupError, SessionError, UserIdRequiredError,
]:
assert issubclass(exc_class, AuthenticationError)
30 changes: 12 additions & 18 deletions backend/tests/test_jwt_local_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,40 +50,34 @@ def test_bearer_prefix_stripped(self, auth_service):

assert result["user_id"] == "user-456"

def test_expired_token_raises_401(self, auth_service):
def test_expired_token_raises_error(self, auth_service):
token = _make_token({"sub": "user-789", "exp": int(time.time()) - 60})

from fastapi import HTTPException
with pytest.raises(HTTPException) as exc:
from services.exceptions import TokenExpiredError
with pytest.raises(TokenExpiredError):
auth_service.verify_jwt(token)
assert exc.value.status_code == 401
assert "expired" in exc.value.detail.lower()

def test_wrong_secret_raises_401(self, auth_service):
def test_wrong_secret_raises_error(self, auth_service):
token = _make_token({"sub": "user-000"}, secret="wrong-secret")

from fastapi import HTTPException
with pytest.raises(HTTPException) as exc:
from services.exceptions import InvalidTokenError
with pytest.raises(InvalidTokenError):
auth_service.verify_jwt(token)
assert exc.value.status_code == 401

def test_missing_sub_claim_raises_401(self, auth_service):
def test_missing_sub_claim_raises_error(self, auth_service):
token = _make_token({"email": "no-sub@test.com"})

from fastapi import HTTPException
with pytest.raises(HTTPException) as exc:
from services.exceptions import TokenMissingClaimError
with pytest.raises(TokenMissingClaimError):
auth_service.verify_jwt(token)
assert exc.value.status_code == 401
assert "subject" in exc.value.detail.lower()

def test_wrong_audience_raises_401(self, auth_service):
def test_wrong_audience_raises_error(self, auth_service):
payload = {"sub": "user-aud", "aud": "wrong-audience", "exp": int(time.time()) + 3600}
token = pyjwt.encode(payload, JWT_SECRET, algorithm="HS256")

from fastapi import HTTPException
with pytest.raises(HTTPException) as exc:
from services.exceptions import InvalidTokenError
with pytest.raises(InvalidTokenError):
auth_service.verify_jwt(token)
assert exc.value.status_code == 401

def test_no_network_call_made(self, auth_service):
"""The whole point of OPE-75: verify_jwt should NOT hit the network."""
Expand Down