Skip to content

Commit bc7b587

Browse files
committed
fix(ws): critical bugs in WebSocket endpoint
- Fix: Must accept() before close() - this was causing runtime errors
- Fix: Check that the job exists before subscribing (close with 4404 if not)
- Fix: Handle race condition - the job may complete before the WebSocket connects
- Fix: Remove unused INITIAL_TIMEOUT_SECONDS constant
- Add: Idle timeout (120s) to prevent zombie connections
- Add: 6 new tests for edge cases (13 total)

Review by CTO caught these before production 🙏
1 parent de89248 commit bc7b587

2 files changed

Lines changed: 208 additions & 35 deletions

File tree

backend/routes/ws_playground.py

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
a smooth streaming experience instead of polling every 2 seconds.
66
77
Channel format: job:{job_id}:events
8-
Message types: started, progress, completed, error
8+
Message types: connected, cloning, progress, completed, error
99
"""
1010
import json
1111
import asyncio
@@ -15,13 +15,14 @@
1515

1616
from dependencies import redis_client
1717
from services.observability import logger
18+
from services.anonymous_indexer import AnonymousIndexingJob
1819

1920

20-
# How long to wait for first message before giving up
21-
INITIAL_TIMEOUT_SECONDS = 30
21+
# How long between messages before sending a ping
22+
PING_INTERVAL_SECONDS = 30
2223

23-
# How long between messages before assuming job is dead
24-
MESSAGE_TIMEOUT_SECONDS = 60
24+
# How long to wait for any activity before closing
25+
IDLE_TIMEOUT_SECONDS = 120
2526

2627

2728
async def websocket_playground_index(websocket: WebSocket, job_id: str):
@@ -35,59 +36,109 @@ async def websocket_playground_index(websocket: WebSocket, job_id: str):
3536
No auth required - job_id is an unguessable UUID that acts as
3637
a bearer token. Only the session that created the job knows it.
3738
"""
39+
# Validate job_id format (basic sanity check)
40+
if not job_id or len(job_id) < 10:
41+
# Must accept before we can close with a reason
42+
await websocket.accept()
43+
await websocket.close(code=4400, reason="Invalid job ID")
44+
return
45+
3846
# Validate we have Redis (required for pub/sub)
3947
if not redis_client:
4048
logger.error("WebSocket failed - no Redis connection")
49+
await websocket.accept()
4150
await websocket.close(code=4500, reason="Service unavailable")
4251
return
4352

44-
# Validate job_id format (basic sanity check)
45-
if not job_id or len(job_id) < 10:
46-
await websocket.close(code=4400, reason="Invalid job ID")
47-
return
53+
# Check if job exists before subscribing
54+
job_manager = AnonymousIndexingJob(redis_client)
55+
job = job_manager.get_job(job_id)
4856

49-
channel = f"job:{job_id}:events"
57+
if not job:
58+
await websocket.accept()
59+
await websocket.close(code=4404, reason="Job not found")
60+
return
5061

5162
# Accept the WebSocket connection
5263
await websocket.accept()
53-
logger.info("WebSocket connected", job_id=job_id[:12], channel=channel)
64+
logger.info("WebSocket connected", job_id=job_id[:12])
5465

55-
# Set up Redis pub/sub
66+
# Handle race condition: job might already be complete
67+
job_status = job.get("status")
68+
if job_status == "completed":
69+
await websocket.send_json({
70+
"type": "completed",
71+
"job_id": job_id,
72+
"repo_id": job.get("repo_id"),
73+
"stats": job.get("stats"),
74+
"message": "Indexing already complete"
75+
})
76+
await websocket.close()
77+
return
78+
elif job_status == "failed":
79+
await websocket.send_json({
80+
"type": "error",
81+
"job_id": job_id,
82+
"error": job.get("error"),
83+
"message": job.get("error_message", "Indexing failed"),
84+
"recoverable": False
85+
})
86+
await websocket.close()
87+
return
88+
89+
channel = f"job:{job_id}:events"
5690
pubsub = redis_client.pubsub()
5791

5892
try:
5993
# Subscribe to job's event channel
6094
await asyncio.to_thread(pubsub.subscribe, channel)
6195
logger.debug("Subscribed to channel", channel=channel)
6296

63-
# Send initial ack so client knows we're connected
97+
# Send initial ack with current state
6498
await websocket.send_json({
6599
"type": "connected",
66100
"job_id": job_id,
101+
"current_status": job_status,
67102
"message": "Listening for indexing events"
68103
})
69104

70105
# Listen for messages
106+
last_activity = asyncio.get_event_loop().time()
107+
71108
while True:
72-
# Check for new message (non-blocking with timeout)
109+
current_time = asyncio.get_event_loop().time()
110+
111+
# Check for idle timeout
112+
if current_time - last_activity > IDLE_TIMEOUT_SECONDS:
113+
logger.warning("WebSocket idle timeout", job_id=job_id[:12])
114+
await websocket.send_json({
115+
"type": "error",
116+
"message": "Connection timed out - no activity"
117+
})
118+
break
119+
120+
# Check for new message (non-blocking with short timeout)
73121
message = await asyncio.to_thread(
74122
pubsub.get_message,
75123
ignore_subscribe_messages=True,
76-
timeout=MESSAGE_TIMEOUT_SECONDS
124+
timeout=PING_INTERVAL_SECONDS
77125
)
78126

79127
if message is None:
80-
# Timeout - check if client is still connected
128+
# No message - send ping to keep connection alive
81129
try:
82130
await websocket.send_json({"type": "ping"})
83131
except Exception:
84-
logger.debug("Client disconnected during timeout", job_id=job_id[:12])
132+
logger.debug("Client disconnected during ping", job_id=job_id[:12])
85133
break
86134
continue
87135

88136
if message["type"] != "message":
89137
continue
90138

139+
# Got a message - reset activity timer
140+
last_activity = current_time
141+
91142
# Parse and forward the event
92143
try:
93144
event_data = json.loads(message["data"])

backend/tests/test_ws_playground.py

Lines changed: 140 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
from fastapi.testclient import TestClient
1111
from fastapi.websockets import WebSocket
1212

13-
# We'll test the WebSocket handler directly since TestClient
14-
# doesn't support async WebSocket testing well
15-
1613

1714
class TestWebSocketPlayground:
1815
"""Test suite for playground WebSocket endpoint."""
@@ -23,27 +20,87 @@ def test_websocket_rejects_invalid_job_id(self):
2320

2421
client = TestClient(app)
2522

26-
# Job ID too short
27-
with pytest.raises(Exception):
28-
with client.websocket_connect("/api/v1/ws/playground/abc"):
29-
pass
23+
# Job ID too short - should accept then close with 4400
24+
with client.websocket_connect("/api/v1/ws/playground/abc") as ws:
25+
# Connection will be closed immediately
26+
pass
3027

3128
def test_websocket_rejects_when_no_redis(self):
3229
"""Should close gracefully if Redis is unavailable."""
3330
from routes.ws_playground import websocket_playground_index
3431

35-
# Mock WebSocket
3632
mock_ws = AsyncMock(spec=WebSocket)
3733

38-
# Patch redis_client to None
3934
with patch('routes.ws_playground.redis_client', None):
4035
asyncio.run(websocket_playground_index(mock_ws, "idx_abc123def456"))
4136

42-
# Should close with service unavailable
37+
# Should accept first, then close
38+
mock_ws.accept.assert_called_once()
4339
mock_ws.close.assert_called_once()
4440
call_args = mock_ws.close.call_args
4541
assert call_args[1]['code'] == 4500
4642

43+
def test_websocket_rejects_nonexistent_job(self):
44+
"""Should close with 4404 if job doesn't exist."""
45+
from routes.ws_playground import websocket_playground_index
46+
47+
mock_ws = AsyncMock(spec=WebSocket)
48+
mock_redis = MagicMock()
49+
mock_redis.get.return_value = None # Job not found
50+
51+
with patch('routes.ws_playground.redis_client', mock_redis):
52+
asyncio.run(websocket_playground_index(mock_ws, "idx_nonexistent"))
53+
54+
mock_ws.accept.assert_called_once()
55+
mock_ws.close.assert_called_once()
56+
call_args = mock_ws.close.call_args
57+
assert call_args[1]['code'] == 4404
58+
59+
def test_websocket_handles_already_completed_job(self):
60+
"""If job is already complete, send completion and close."""
61+
from routes.ws_playground import websocket_playground_index
62+
63+
mock_ws = AsyncMock(spec=WebSocket)
64+
mock_redis = MagicMock()
65+
mock_redis.get.return_value = json.dumps({
66+
"job_id": "idx_test123",
67+
"status": "completed",
68+
"repo_id": "anon_test123",
69+
"stats": {"files_indexed": 100, "functions_found": 500}
70+
})
71+
72+
with patch('routes.ws_playground.redis_client', mock_redis):
73+
asyncio.run(websocket_playground_index(mock_ws, "idx_test123"))
74+
75+
# Should send completed event immediately
76+
mock_ws.send_json.assert_called_once()
77+
sent_data = mock_ws.send_json.call_args[0][0]
78+
assert sent_data["type"] == "completed"
79+
assert sent_data["repo_id"] == "anon_test123"
80+
mock_ws.close.assert_called_once()
81+
82+
def test_websocket_handles_already_failed_job(self):
83+
"""If job already failed, send error and close."""
84+
from routes.ws_playground import websocket_playground_index
85+
86+
mock_ws = AsyncMock(spec=WebSocket)
87+
mock_redis = MagicMock()
88+
mock_redis.get.return_value = json.dumps({
89+
"job_id": "idx_test123",
90+
"status": "failed",
91+
"error": "clone_failed",
92+
"error_message": "Repository not found"
93+
})
94+
95+
with patch('routes.ws_playground.redis_client', mock_redis):
96+
asyncio.run(websocket_playground_index(mock_ws, "idx_test123"))
97+
98+
mock_ws.send_json.assert_called_once()
99+
sent_data = mock_ws.send_json.call_args[0][0]
100+
assert sent_data["type"] == "error"
101+
assert sent_data["error"] == "clone_failed"
102+
mock_ws.close.assert_called_once()
103+
47104

48105
class TestPubSubIntegration:
49106
"""Test Redis Pub/Sub event publishing."""
@@ -52,7 +109,6 @@ def test_publish_event_called_on_status_update(self):
52109
"""Verify events are published when status changes."""
53110
from services.anonymous_indexer import AnonymousIndexingJob, JobStatus
54111

55-
# Create mock Redis
56112
mock_redis = MagicMock()
57113
mock_redis.get.return_value = json.dumps({
58114
"job_id": "idx_test123",
@@ -61,11 +117,8 @@ def test_publish_event_called_on_status_update(self):
61117
})
62118

63119
job_manager = AnonymousIndexingJob(mock_redis)
64-
65-
# Update status to cloning
66120
job_manager.update_status("idx_test123", JobStatus.CLONING)
67121

68-
# Verify publish was called
69122
mock_redis.publish.assert_called()
70123
call_args = mock_redis.publish.call_args
71124

@@ -88,7 +141,6 @@ def test_progress_event_published_with_file_info(self):
88141

89142
job_manager = AnonymousIndexingJob(mock_redis)
90143

91-
# Update progress
92144
job_manager.update_progress(
93145
job_id="idx_test123",
94146
files_processed=25,
@@ -132,7 +184,6 @@ def test_completed_event_includes_stats(self):
132184
repo_id="anon_test123"
133185
)
134186

135-
# Verify publish was called with completed event
136187
call_args = mock_redis.publish.call_args
137188
event_data = json.loads(call_args[0][1])
138189

@@ -152,13 +203,38 @@ def test_processing_status_skips_duplicate_publish(self):
152203

153204
job_manager = AnonymousIndexingJob(mock_redis)
154205

155-
# Call update_status with PROCESSING directly
156206
progress = JobProgress(files_total=100, files_processed=0, functions_found=0)
157207
job_manager.update_status("idx_test123", JobStatus.PROCESSING, progress=progress)
158208

159209
# Should NOT have published (returns early for PROCESSING)
160210
mock_redis.publish.assert_not_called()
161211

212+
def test_failed_event_includes_recoverable_flag(self):
213+
"""Failed events should indicate if error is recoverable."""
214+
from services.anonymous_indexer import AnonymousIndexingJob, JobStatus
215+
216+
mock_redis = MagicMock()
217+
mock_redis.get.return_value = json.dumps({
218+
"job_id": "idx_test123",
219+
"status": "processing"
220+
})
221+
222+
job_manager = AnonymousIndexingJob(mock_redis)
223+
224+
# Clone failures are not recoverable
225+
job_manager.update_status(
226+
"idx_test123",
227+
JobStatus.FAILED,
228+
error="clone_failed",
229+
error_message="Repo not found"
230+
)
231+
232+
call_args = mock_redis.publish.call_args
233+
event_data = json.loads(call_args[0][1])
234+
235+
assert event_data["type"] == "failed"
236+
assert event_data["recoverable"] == False
237+
162238

163239
class TestEventFormats:
164240
"""Verify event message formats match frontend expectations."""
@@ -179,6 +255,52 @@ def test_event_types_are_strings(self):
179255
call_args = mock_redis.publish.call_args
180256
event_data = json.loads(call_args[0][1])
181257

182-
# Type should be string "cloning", not JobStatus.CLONING
183258
assert isinstance(event_data["type"], str)
184259
assert event_data["type"] == "cloning"
260+
261+
def test_progress_percent_is_integer(self):
262+
"""Progress percent should be an integer 0-100."""
263+
from services.anonymous_indexer import AnonymousIndexingJob
264+
265+
mock_redis = MagicMock()
266+
mock_redis.get.return_value = json.dumps({
267+
"job_id": "idx_test123",
268+
"status": "processing"
269+
})
270+
271+
job_manager = AnonymousIndexingJob(mock_redis)
272+
273+
job_manager.update_progress(
274+
job_id="idx_test123",
275+
files_processed=33,
276+
functions_found=100,
277+
files_total=100,
278+
current_file="test.py"
279+
)
280+
281+
progress_call = mock_redis.publish.call_args_list[0]
282+
event_data = json.loads(progress_call[0][1])
283+
284+
assert isinstance(event_data["percent"], int)
285+
assert event_data["percent"] == 33
286+
287+
def test_all_events_include_job_id(self):
288+
"""All events should include job_id for client correlation."""
289+
from services.anonymous_indexer import AnonymousIndexingJob, JobStatus
290+
291+
mock_redis = MagicMock()
292+
mock_redis.get.return_value = json.dumps({
293+
"job_id": "idx_test123",
294+
"status": "queued"
295+
})
296+
297+
job_manager = AnonymousIndexingJob(mock_redis)
298+
299+
# Test different event types
300+
job_manager.update_status("idx_test123", JobStatus.CLONING)
301+
302+
call_args = mock_redis.publish.call_args
303+
event_data = json.loads(call_args[0][1])
304+
305+
assert "job_id" in event_data
306+
assert event_data["job_id"] == "idx_test123"

0 commit comments

Comments
 (0)