Skip to content

Commit 68ac966

Browse files
authored
Merge pull request #259 from DevanshuNEU/fix/auth-hardening
fix: auth hardening -- domain exceptions + null user_id safety (OPE-76, OPE-77)
2 parents f211731 + 38e0e40 commit 68ac966

6 files changed

Lines changed: 192 additions & 54 deletions

File tree

backend/dependencies.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Shared dependencies and service instances.
33
All route modules import from here to avoid circular imports.
44
"""
5+
from typing import Optional
56
from fastapi import HTTPException, Depends
67

78
from services.indexer_optimized import OptimizedCodeIndexer
@@ -43,21 +44,25 @@
4344
redis_client = cache.redis if cache.redis else None
4445

4546

46-
def get_repo_or_404(repo_id: str, user_id: str) -> dict:
47+
def get_repo_or_404(repo_id: str, user_id: Optional[str]) -> dict:
4748
"""
4849
Get repository with ownership verification.
49-
Returns 404 if not found or user doesn't own it.
50+
Raises 401 if user_id is None, 404 if not found or user doesn't own it.
5051
"""
52+
if user_id is None:
53+
raise HTTPException(status_code=401, detail="User ID required for this operation")
5154
repo = repo_manager.get_repo_for_user(repo_id, user_id)
5255
if not repo:
5356
raise HTTPException(status_code=404, detail="Repository not found")
5457
return repo
5558

5659

57-
def verify_repo_access(repo_id: str, user_id: str) -> None:
60+
def verify_repo_access(repo_id: str, user_id: Optional[str]) -> None:
5861
"""
5962
Verify user has access to repository.
60-
Raises 404 if no access.
63+
Raises 401 if user_id is None, 404 if no access.
6164
"""
65+
if user_id is None:
66+
raise HTTPException(status_code=401, detail="User ID required for this operation")
6267
if not repo_manager.verify_ownership(repo_id, user_id):
6368
raise HTTPException(status_code=404, detail="Repository not found")

backend/middleware/auth.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ async def demo_search(auth: AuthContext = Depends(public_auth)):
2424
import os
2525
import hashlib
2626

27+
from services.exceptions import (
28+
AuthenticationError,
29+
TokenExpiredError,
30+
InvalidTokenError,
31+
TokenMissingClaimError,
32+
)
33+
2734
from fastapi import Depends, HTTPException, status
2835
from fastapi.security import HTTPBearer
2936
from fastapi.security.http import HTTPAuthorizationCredentials
@@ -59,7 +66,11 @@ def identifier(self) -> str:
5966
# Core validation functions
6067

6168
def _validate_jwt(token: str) -> Optional[AuthContext]:
62-
"""Validate Supabase JWT token"""
69+
"""Validate Supabase JWT token.
70+
71+
Returns AuthContext on success, None if token isn't a JWT (allows API key fallback).
72+
Raises HTTPException for specific JWT failures (expired, bad audience, missing claim).
73+
"""
6374
try:
6475
from services.auth import get_auth_service
6576
auth_service = get_auth_service()
@@ -70,6 +81,15 @@ def _validate_jwt(token: str) -> Optional[AuthContext]:
7081
email=user.get("email"),
7182
tier=user.get("metadata", {}).get("tier", "free")
7283
)
84+
except TokenExpiredError:
85+
raise HTTPException(status_code=401, detail="Token expired")
86+
except TokenMissingClaimError as e:
87+
raise HTTPException(status_code=401, detail=str(e))
88+
except InvalidTokenError:
89+
# Could be a non-JWT token (API key) -- allow fallback
90+
return None
91+
except AuthenticationError as e:
92+
raise HTTPException(status_code=401, detail=str(e))
7393
except Exception:
7494
return None
7595

@@ -181,7 +201,12 @@ async def get_current_user(
181201
"""
182202
from services.auth import get_auth_service
183203
auth_service = get_auth_service()
184-
return auth_service.verify_jwt(credentials.credentials)
204+
try:
205+
return auth_service.verify_jwt(credentials.credentials)
206+
except (TokenExpiredError, TokenMissingClaimError) as e:
207+
raise HTTPException(status_code=401, detail=str(e))
208+
except AuthenticationError:
209+
raise HTTPException(status_code=401, detail="Invalid or expired token")
185210

186211

187212
async def get_optional_user(

backend/services/auth.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55
from fastapi import HTTPException, status
66
from typing import Optional, Dict, Any
77
import os
8-
import jwt
8+
import jwt as pyjwt
99
from datetime import datetime
1010
from supabase import create_client, Client
1111

1212
from services.observability import logger
13+
from services.exceptions import (
14+
AuthenticationError,
15+
TokenExpiredError,
16+
InvalidTokenError,
17+
TokenMissingClaimError,
18+
)
1319

1420

1521
class SupabaseAuthService:
@@ -48,7 +54,7 @@ def verify_jwt(self, token: str) -> Dict[str, Any]:
4854
def _verify_local(self, token: str) -> Dict[str, Any]:
4955
"""Decode and verify JWT locally with HS256 secret."""
5056
try:
51-
payload = jwt.decode(
57+
payload = pyjwt.decode(
5258
token,
5359
self.jwt_secret,
5460
algorithms=["HS256"],
@@ -58,58 +64,40 @@ def _verify_local(self, token: str) -> Dict[str, Any]:
5864

5965
user_id = payload.get("sub")
6066
if not user_id:
61-
raise HTTPException(
62-
status_code=status.HTTP_401_UNAUTHORIZED,
63-
detail="Token missing subject claim",
64-
)
67+
raise TokenMissingClaimError("sub")
6568

6669
return {
6770
"user_id": user_id,
6871
"email": payload.get("email"),
6972
"metadata": payload.get("user_metadata") or {},
7073
}
7174

72-
except jwt.ExpiredSignatureError:
73-
raise HTTPException(
74-
status_code=status.HTTP_401_UNAUTHORIZED,
75-
detail="Token expired",
76-
)
77-
except jwt.InvalidAudienceError:
78-
raise HTTPException(
79-
status_code=status.HTTP_401_UNAUTHORIZED,
80-
detail="Invalid token audience",
81-
)
82-
except jwt.InvalidTokenError as e:
75+
except pyjwt.ExpiredSignatureError as e:
76+
raise TokenExpiredError("Token expired") from e
77+
except pyjwt.InvalidAudienceError as e:
78+
raise InvalidTokenError("Invalid token audience") from e
79+
except pyjwt.InvalidTokenError as e:
8380
logger.debug("JWT decode failed", error=str(e))
84-
raise HTTPException(
85-
status_code=status.HTTP_401_UNAUTHORIZED,
86-
detail="Invalid token",
87-
)
81+
raise InvalidTokenError("Invalid token") from e
8882

8983
def _verify_via_api(self, token: str) -> Dict[str, Any]:
9084
"""Fallback: verify via Supabase API call (requires network)."""
9185
try:
9286
response = self.client.auth.get_user(token)
9387

9488
if not response.user:
95-
raise HTTPException(
96-
status_code=status.HTTP_401_UNAUTHORIZED,
97-
detail="Invalid or expired token",
98-
)
89+
raise InvalidTokenError("Invalid or expired token")
9990

10091
return {
10192
"user_id": response.user.id,
10293
"email": response.user.email,
10394
"metadata": response.user.user_metadata or {},
10495
}
105-
except HTTPException:
96+
except AuthenticationError:
10697
raise
10798
except Exception as e:
10899
logger.debug("API-based JWT verification failed", error=str(e))
109-
raise HTTPException(
110-
status_code=status.HTTP_401_UNAUTHORIZED,
111-
detail="Token verification failed",
112-
)
100+
raise AuthenticationError("Token verification failed") from e
113101

114102
async def signup(self, email: str, password: str, github_username: Optional[str] = None) -> Dict[str, Any]:
115103
"""

backend/services/exceptions.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""
2+
Domain exceptions for authentication.
3+
4+
Services raise these; the middleware/route layer translates them to HTTP responses.
5+
This decouples business logic from the HTTP framework.
6+
"""
7+
8+
9+
class AuthenticationError(Exception):
10+
"""Base auth error. All auth exceptions inherit from this."""
11+
pass
12+
13+
14+
class TokenExpiredError(AuthenticationError):
15+
"""JWT token has expired."""
16+
pass
17+
18+
19+
class InvalidTokenError(AuthenticationError):
20+
"""JWT token is malformed, wrong signature, or invalid audience."""
21+
pass
22+
23+
24+
class TokenMissingClaimError(AuthenticationError):
25+
"""JWT token is missing a required claim (e.g., sub)."""
26+
27+
def __init__(self, claim: str = "sub"):
28+
super().__init__(f"Token missing required claim: {claim}")
29+
self.claim = claim
30+
31+
32+
class InvalidCredentialsError(AuthenticationError):
33+
"""Login failed due to wrong email/password."""
34+
pass
35+
36+
37+
class SignupError(AuthenticationError):
38+
"""User registration failed."""
39+
pass
40+
41+
42+
class SessionError(AuthenticationError):
43+
"""Token refresh or logout failed."""
44+
pass
45+
46+
47+
class UserIdRequiredError(AuthenticationError):
48+
"""Operation requires a user_id but auth context has None (e.g., API key without user)."""
49+
pass
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Tests for auth hardening -- domain exceptions + null safety (OPE-76, OPE-77)."""
2+
import pytest
3+
from unittest.mock import patch, MagicMock
4+
from fastapi import HTTPException
5+
6+
7+
class TestNullUserIdSafety:
8+
"""API key users (no user_id) get 401, not confusing 404."""
9+
10+
def test_search_with_null_user_id_returns_401(self, client, valid_headers):
11+
"""Search should reject None user_id via the real verify_repo_access null guard."""
12+
from fastapi.testclient import TestClient
13+
from main import app
14+
from middleware.auth import require_auth, AuthContext
15+
16+
# Mock auth to return a context with no user_id (API key user)
17+
async def mock_auth_no_user():
18+
return AuthContext(api_key_name="test-key", user_id=None)
19+
20+
app.dependency_overrides[require_auth] = mock_auth_no_user
21+
try:
22+
resp = client.post(
23+
"/api/v1/search",
24+
json={"query": "auth", "repo_id": "test"},
25+
headers=valid_headers,
26+
)
27+
# verify_repo_access in dependencies.py should catch None user_id
28+
assert resp.status_code == 401
29+
assert "User ID required" in resp.json()["detail"]
30+
finally:
31+
app.dependency_overrides.pop(require_auth, None)
32+
33+
def test_get_repo_or_404_rejects_none_user_id(self):
34+
"""get_repo_or_404 should raise 401 when user_id is None."""
35+
from dependencies import get_repo_or_404
36+
from fastapi import HTTPException
37+
38+
with pytest.raises(HTTPException) as exc:
39+
get_repo_or_404("some-repo", None)
40+
assert exc.value.status_code == 401
41+
42+
def test_verify_repo_access_rejects_none_user_id(self):
43+
"""verify_repo_access should raise 401 when user_id is None."""
44+
from dependencies import verify_repo_access
45+
from fastapi import HTTPException
46+
47+
with pytest.raises(HTTPException) as exc:
48+
verify_repo_access("some-repo", None)
49+
assert exc.value.status_code == 401
50+
51+
52+
class TestDomainExceptions:
53+
"""Auth service raises domain exceptions, not HTTPException."""
54+
55+
def test_expired_token_raises_domain_exception(self):
56+
"""Auth service should raise TokenExpiredError, not HTTPException."""
57+
from services.exceptions import TokenExpiredError
58+
assert issubclass(TokenExpiredError, Exception)
59+
assert not issubclass(TokenExpiredError, HTTPException)
60+
61+
def test_exception_hierarchy(self):
62+
"""All auth exceptions inherit from AuthenticationError."""
63+
from services.exceptions import (
64+
AuthenticationError,
65+
TokenExpiredError,
66+
InvalidTokenError,
67+
TokenMissingClaimError,
68+
InvalidCredentialsError,
69+
SignupError,
70+
SessionError,
71+
UserIdRequiredError,
72+
)
73+
for exc_class in [
74+
TokenExpiredError, InvalidTokenError, TokenMissingClaimError,
75+
InvalidCredentialsError, SignupError, SessionError, UserIdRequiredError,
76+
]:
77+
assert issubclass(exc_class, AuthenticationError)

backend/tests/test_jwt_local_decode.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,40 +50,34 @@ def test_bearer_prefix_stripped(self, auth_service):
5050

5151
assert result["user_id"] == "user-456"
5252

53-
def test_expired_token_raises_401(self, auth_service):
53+
def test_expired_token_raises_error(self, auth_service):
5454
token = _make_token({"sub": "user-789", "exp": int(time.time()) - 60})
5555

56-
from fastapi import HTTPException
57-
with pytest.raises(HTTPException) as exc:
56+
from services.exceptions import TokenExpiredError
57+
with pytest.raises(TokenExpiredError):
5858
auth_service.verify_jwt(token)
59-
assert exc.value.status_code == 401
60-
assert "expired" in exc.value.detail.lower()
6159

62-
def test_wrong_secret_raises_401(self, auth_service):
60+
def test_wrong_secret_raises_error(self, auth_service):
6361
token = _make_token({"sub": "user-000"}, secret="wrong-secret")
6462

65-
from fastapi import HTTPException
66-
with pytest.raises(HTTPException) as exc:
63+
from services.exceptions import InvalidTokenError
64+
with pytest.raises(InvalidTokenError):
6765
auth_service.verify_jwt(token)
68-
assert exc.value.status_code == 401
6966

70-
def test_missing_sub_claim_raises_401(self, auth_service):
67+
def test_missing_sub_claim_raises_error(self, auth_service):
7168
token = _make_token({"email": "no-sub@test.com"})
7269

73-
from fastapi import HTTPException
74-
with pytest.raises(HTTPException) as exc:
70+
from services.exceptions import TokenMissingClaimError
71+
with pytest.raises(TokenMissingClaimError):
7572
auth_service.verify_jwt(token)
76-
assert exc.value.status_code == 401
77-
assert "subject" in exc.value.detail.lower()
7873

79-
def test_wrong_audience_raises_401(self, auth_service):
74+
def test_wrong_audience_raises_error(self, auth_service):
8075
payload = {"sub": "user-aud", "aud": "wrong-audience", "exp": int(time.time()) + 3600}
8176
token = pyjwt.encode(payload, JWT_SECRET, algorithm="HS256")
8277

83-
from fastapi import HTTPException
84-
with pytest.raises(HTTPException) as exc:
78+
from services.exceptions import InvalidTokenError
79+
with pytest.raises(InvalidTokenError):
8580
auth_service.verify_jwt(token)
86-
assert exc.value.status_code == 401
8781

8882
def test_no_network_call_made(self, auth_service):
8983
"""The whole point of OPE-75: verify_jwt should NOT hit the network."""

0 commit comments

Comments
 (0)