Skip to content

Commit 1039d34

Browse files
committed
fix(security): add JWT authentication to WebSocket endpoint
- Add authenticate_websocket() helper for token validation via query param - Reject connections without token (4001) or with invalid token (4001) - Validate repo exists before accepting connection (4004) - Add TODO for repo ownership validation (needs user_id column) - Add unit tests for WebSocket authentication - Configure pytest for async tests Fixes #6
1 parent aa403e7 commit 1039d34

3 files changed

Lines changed: 135 additions & 8 deletions

File tree

backend/main.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,19 +180,56 @@ async def add_repository(
180180
raise HTTPException(status_code=400, detail=str(e))
181181

182182

183+
async def authenticate_websocket(websocket: WebSocket) -> Optional[dict]:
184+
"""
185+
Authenticate WebSocket connection via query parameter token.
186+
187+
WebSockets can't use Authorization headers during handshake,
188+
so we pass the JWT token as a query parameter instead.
189+
190+
Returns:
191+
User dict if authenticated, None otherwise (connection closed with error)
192+
"""
193+
token = websocket.query_params.get("token")
194+
if not token:
195+
await websocket.close(code=4001, reason="Missing authentication token")
196+
return None
197+
198+
try:
199+
from services.auth import get_auth_service
200+
auth_service = get_auth_service()
201+
return auth_service.verify_jwt(token)
202+
except Exception:
203+
await websocket.close(code=4001, reason="Invalid or expired token")
204+
return None
205+
206+
183207
@app.websocket("/ws/index/{repo_id}")
184208
async def websocket_index(websocket: WebSocket, repo_id: str):
185-
"""Real-time indexing with progress updates"""
209+
"""
210+
Real-time repository indexing with progress updates.
211+
212+
Requires JWT token passed as query parameter: ?token=<jwt>
213+
Sends progress updates via JSON messages during indexing.
214+
"""
215+
# Authenticate before accepting connection
216+
user = await authenticate_websocket(websocket)
217+
if not user:
218+
return
219+
220+
# TODO: Add repo ownership validation once user_id column exists in repos table
221+
# For now, any authenticated user can index any repo they know the ID of
222+
223+
# Validate repo exists before accepting connection
224+
repo = repo_manager.get_repo(repo_id)
225+
if not repo:
226+
await websocket.close(code=4004, reason="Repository not found")
227+
return
228+
229+
# Connection authenticated and repo valid - accept
186230
await websocket.accept()
187231

188232
try:
189-
# Get repo info
190-
repo = repo_manager.get_repo(repo_id)
191-
if not repo:
192-
await websocket.send_json({"error": "Repository not found"})
193-
await websocket.close()
194-
return
195-
196233
repo_manager.update_status(repo_id, "indexing")
197234

198235
# Index with progress callback

backend/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ testpaths = tests
33
python_files = test_*.py
44
python_classes = Test*
55
python_functions = test_*
6+
asyncio_mode = auto
67
addopts =
78
-v
89
--tb=short
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
WebSocket Authentication Tests
3+
Tests for issue #6: Secure WebSocket endpoints
4+
"""
5+
import pytest
6+
from unittest.mock import MagicMock, AsyncMock, patch
7+
8+
9+
class TestWebSocketAuthentication:
10+
"""Integration tests for WebSocket authentication via query parameter token"""
11+
12+
def test_websocket_rejects_missing_token(self, client):
13+
"""WebSocket should reject connections without token (4001)"""
14+
with pytest.raises(Exception):
15+
with client.websocket_connect("/ws/index/test-repo-id"):
16+
pass
17+
18+
def test_websocket_rejects_invalid_token(self, client):
19+
"""WebSocket should reject connections with invalid token (4001)"""
20+
with pytest.raises(Exception):
21+
with client.websocket_connect("/ws/index/test-repo-id?token=invalid-token"):
22+
pass
23+
24+
def test_websocket_rejects_nonexistent_repo(self, client):
25+
"""WebSocket should reject if repo doesn't exist (4004)"""
26+
with patch('main.authenticate_websocket') as mock_auth:
27+
mock_auth.return_value = {"user_id": "test-user", "email": "test@example.com"}
28+
29+
with pytest.raises(Exception):
30+
with client.websocket_connect("/ws/index/nonexistent-repo?token=valid"):
31+
pass
32+
33+
34+
class TestAuthenticateWebsocketFunction:
35+
"""Unit tests for the authenticate_websocket helper"""
36+
37+
@pytest.mark.asyncio
38+
async def test_returns_none_without_token(self):
39+
"""Should return None and close connection if no token provided"""
40+
from main import authenticate_websocket
41+
42+
mock_ws = MagicMock()
43+
mock_ws.query_params = {}
44+
mock_ws.close = AsyncMock()
45+
46+
result = await authenticate_websocket(mock_ws)
47+
48+
assert result is None
49+
mock_ws.close.assert_called_once_with(code=4001, reason="Missing authentication token")
50+
51+
@pytest.mark.asyncio
52+
async def test_returns_none_with_invalid_token(self):
53+
"""Should return None and close connection if token is invalid"""
54+
from main import authenticate_websocket
55+
56+
mock_ws = MagicMock()
57+
mock_ws.query_params = {"token": "invalid-token"}
58+
mock_ws.close = AsyncMock()
59+
60+
with patch('services.auth.get_auth_service') as mock_get_service:
61+
mock_service = MagicMock()
62+
mock_service.verify_jwt.side_effect = Exception("Invalid token")
63+
mock_get_service.return_value = mock_service
64+
65+
result = await authenticate_websocket(mock_ws)
66+
67+
assert result is None
68+
mock_ws.close.assert_called_once_with(code=4001, reason="Invalid or expired token")
69+
70+
@pytest.mark.asyncio
71+
async def test_returns_user_with_valid_token(self):
72+
"""Should return user dict if token is valid"""
73+
from main import authenticate_websocket
74+
75+
mock_ws = MagicMock()
76+
mock_ws.query_params = {"token": "valid-jwt-token"}
77+
mock_ws.close = AsyncMock()
78+
79+
expected_user = {"user_id": "user-123", "email": "test@example.com"}
80+
81+
with patch('services.auth.get_auth_service') as mock_get_service:
82+
mock_service = MagicMock()
83+
mock_service.verify_jwt.return_value = expected_user
84+
mock_get_service.return_value = mock_service
85+
86+
result = await authenticate_websocket(mock_ws)
87+
88+
assert result == expected_user
89+
mock_ws.close.assert_not_called()

0 commit comments

Comments
 (0)