Skip to content

Commit de89248

Browse files
committed
feat(backend): WebSocket endpoint for real-time indexing progress (#149)
- Add /api/v1/ws/playground/{job_id} WebSocket endpoint
- Implement Redis Pub/Sub for real-time event streaming
- Publish events: connected, cloning, progress, completed, error
- Progress events include current_file for streaming file list
- Skip duplicate PROCESSING events (handled by update_progress)
- Add comprehensive test suite (7 tests)

Closes #149
1 parent a35a035 commit de89248

4 files changed

Lines changed: 394 additions & 3 deletions

File tree

backend/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from routes.api_keys import router as api_keys_router
2828
from routes.users import router as users_router
2929
from routes.search_v2 import router as search_v2_router
30+
from routes.ws_playground import websocket_playground_index
3031

3132

3233
# Lifespan context manager for startup/shutdown
@@ -91,8 +92,9 @@ async def dispatch(self, request: Request, call_next):
9192
app.include_router(users_router, prefix=API_PREFIX)
9293
app.include_router(search_v2_router, prefix=API_PREFIX)
9394

94-
# WebSocket endpoint (versioned)
95+
# WebSocket endpoints (versioned)
9596
app.add_api_websocket_route(f"{API_PREFIX}/ws/index/{{repo_id}}", websocket_index)
97+
app.add_api_websocket_route(f"{API_PREFIX}/ws/playground/{{job_id}}", websocket_playground_index)
9698

9799

98100
# ===== ERROR HANDLERS =====

backend/routes/ws_playground.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""
2+
WebSocket endpoint for real-time playground indexing progress.
3+
4+
This provides instant updates as files are indexed, giving users
5+
a smooth streaming experience instead of polling every 2 seconds.
6+
7+
Channel format: job:{job_id}:events
8+
Message types: connected, ping, progress, completed, error (plus job status events such as cloning)
9+
"""
10+
import json
11+
import asyncio
12+
from typing import Optional
13+
14+
from fastapi import WebSocket, WebSocketDisconnect
15+
16+
from dependencies import redis_client
17+
from services.observability import logger
18+
19+
20+
# How long to wait for first message before giving up
# NOTE: not currently referenced by the handler in this module
21+
INITIAL_TIMEOUT_SECONDS = 30
22+
23+
# How long to block waiting for a pub/sub message before pinging the client
# to check that the connection is still alive
24+
MESSAGE_TIMEOUT_SECONDS = 60
25+
26+
27+
async def websocket_playground_index(websocket: WebSocket, job_id: str):
    """
    Stream indexing progress for one playground job to a WebSocket client.

    Subscribes to the Redis pub/sub channel ``job:{job_id}:events`` and
    forwards every JSON event published there to the client. The handler
    stops (and closes the socket) when a terminal event has been forwarded,
    when sending to the client fails, or when an unexpected error occurs.

    No auth is performed here - the job_id acts as a bearer token. The
    design assumes it is an unguessable UUID known only to the session
    that created the job; this handler itself only checks that it is at
    least 10 characters long.

    Args:
        websocket: The incoming WebSocket connection (not yet accepted).
        job_id: Identifier of the indexing job whose events are streamed.

    Close codes used before the connection is accepted:
        4500 - Redis is not available.
        4400 - job_id is missing or too short.
    """
    # Validate we have Redis (required for pub/sub)
    if not redis_client:
        logger.error("WebSocket failed - no Redis connection")
        await websocket.close(code=4500, reason="Service unavailable")
        return

    # Validate job_id format (basic sanity check)
    if not job_id or len(job_id) < 10:
        await websocket.close(code=4400, reason="Invalid job ID")
        return

    channel = f"job:{job_id}:events"

    # Accept the WebSocket connection
    await websocket.accept()
    logger.info("WebSocket connected", job_id=job_id[:12], channel=channel)

    # Set up Redis pub/sub
    pubsub = redis_client.pubsub()

    try:
        # The Redis client is used synchronously, so every blocking call is
        # pushed to a worker thread to keep the event loop responsive.
        await asyncio.to_thread(pubsub.subscribe, channel)
        logger.debug("Subscribed to channel", channel=channel)

        # Send initial ack so client knows we're connected
        await websocket.send_json({
            "type": "connected",
            "job_id": job_id,
            "message": "Listening for indexing events"
        })

        while True:
            # Block (in a worker thread) until a message arrives or the
            # timeout elapses; None means the timeout elapsed.
            message = await asyncio.to_thread(
                pubsub.get_message,
                ignore_subscribe_messages=True,
                timeout=MESSAGE_TIMEOUT_SECONDS
            )

            if message is None:
                # Quiet period - ping to find out whether the client is
                # still there; a failed send means it is gone.
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception:
                    logger.debug("Client disconnected during timeout", job_id=job_id[:12])
                    break
                continue

            if message["type"] != "message":
                continue

            # Parse first, separately from sending, so that a malformed
            # payload (skip it) is not confused with a dead client (stop).
            try:
                event_data = json.loads(message["data"])
            except (ValueError, TypeError):
                # ValueError covers JSONDecodeError and undecodable bytes.
                logger.warning("Invalid JSON in pub/sub message", job_id=job_id[:12])
                continue

            try:
                await websocket.send_json(event_data)
            except Exception as e:
                # The payload was just parsed from JSON, so a failure here
                # is a connection problem. Stop instead of looping on a
                # dead socket and holding the Redis subscription open.
                logger.debug(
                    "Client disconnected while forwarding event",
                    error=str(e),
                    job_id=job_id[:12]
                )
                break

            # Only JSON objects carry an event type; other JSON values are
            # forwarded as-is and never treated as terminal.
            event_type = event_data.get("type") if isinstance(event_data, dict) else None

            # Close connection after terminal events.
            # NOTE(review): the publisher uses the job status value as the
            # event type; "failed" is accepted here in case that is the
            # value of the FAILED status - confirm against JobStatus.
            if event_type in ("completed", "error", "failed"):
                logger.info(
                    "Job finished, closing WebSocket",
                    job_id=job_id[:12],
                    event_type=event_type
                )
                break

    except WebSocketDisconnect:
        logger.debug("WebSocket disconnected by client", job_id=job_id[:12])

    except Exception as e:
        logger.error("WebSocket error", error=str(e), job_id=job_id[:12])
        # Best effort: tell the client something went wrong. The socket may
        # already be unusable, in which case there is nothing more to do.
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Internal server error"
            })
        except Exception:
            pass

    finally:
        # Clean up pub/sub. Each step is attempted independently so that a
        # failing unsubscribe does not prevent the connection being closed.
        try:
            await asyncio.to_thread(pubsub.unsubscribe, channel)
        except Exception as e:
            logger.debug("Pub/sub unsubscribe failed", error=str(e), job_id=job_id[:12])

        try:
            await asyncio.to_thread(pubsub.close)
        except Exception as e:
            logger.debug("Pub/sub close failed", error=str(e), job_id=job_id[:12])

        # Close WebSocket if still open; closing an already-closed socket
        # raises, which is expected and ignored.
        try:
            await websocket.close()
        except Exception:
            pass

        logger.debug("WebSocket cleanup complete", job_id=job_id[:12])

backend/services/anonymous_indexer.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class AnonymousIndexingJob:
6060
"""
6161

6262
REDIS_PREFIX = "anon_job:"
63+
PUBSUB_PREFIX = "job:" # Channel: job:{job_id}:events
6364
JOB_TTL_SECONDS = 3600 # 1 hour for job metadata
6465
REPO_TTL_HOURS = 24 # 24 hours for indexed data
6566
TEMP_DIR = "/tmp/anon_repos"
@@ -71,6 +72,27 @@ def __init__(self, redis_client):
7172
# Ensure temp directory exists
7273
Path(self.TEMP_DIR).mkdir(parents=True, exist_ok=True)
7374

75+
def _get_channel(self, job_id: str) -> str:
76+
"""Get Redis pub/sub channel for job events."""
77+
return f"{self.PUBSUB_PREFIX}{job_id}:events"
78+
79+
def _publish_event(self, job_id: str, event: dict) -> None:
80+
"""
81+
Publish event to Redis pub/sub for real-time WebSocket updates.
82+
83+
Events are fire-and-forget - if no one is listening, that's fine.
84+
The job state in Redis is the source of truth for polling fallback.
85+
"""
86+
if not self.redis:
87+
return
88+
89+
try:
90+
channel = self._get_channel(job_id)
91+
self.redis.publish(channel, json.dumps(event))
92+
except Exception as e:
93+
# Don't fail the job if pub/sub fails - it's nice-to-have
94+
logger.warning("Failed to publish event", job_id=job_id, error=str(e))
95+
7496
@staticmethod
7597
def generate_job_id() -> str:
7698
"""Generate unique job ID."""
@@ -160,7 +182,7 @@ def update_status(
160182
error: Optional[str] = None,
161183
error_message: Optional[str] = None
162184
) -> bool:
163-
"""Update job status in Redis."""
185+
"""Update job status in Redis and publish event to WebSocket clients."""
164186
if not self.redis:
165187
return False
166188

@@ -185,6 +207,30 @@ def update_status(
185207
key = self._get_key(job_id)
186208
self.redis.setex(key, self.JOB_TTL_SECONDS, json.dumps(job))
187209

210+
# Publish status change event for WebSocket clients
211+
# Note: PROCESSING events are handled by update_progress() to avoid duplicates
212+
event = {"type": status.value, "job_id": job_id}
213+
214+
if status == JobStatus.QUEUED:
215+
event["message"] = "Job queued for processing"
216+
elif status == JobStatus.CLONING:
217+
event["message"] = "Cloning repository..."
218+
event["repo_name"] = job.get("repo_name")
219+
elif status == JobStatus.PROCESSING:
220+
# Skip publishing here - update_progress() sends granular progress events
221+
return True
222+
elif status == JobStatus.COMPLETED:
223+
event["message"] = "Indexing complete!"
224+
event["repo_id"] = repo_id
225+
if stats:
226+
event["stats"] = stats.to_dict()
227+
elif status == JobStatus.FAILED:
228+
event["message"] = error_message or "Indexing failed"
229+
event["error"] = error
230+
event["recoverable"] = error not in ("timeout", "clone_failed")
231+
232+
self._publish_event(job_id, event)
233+
188234
return True
189235

190236
def update_progress(
@@ -195,7 +241,26 @@ def update_progress(
195241
files_total: int,
196242
current_file: Optional[str] = None
197243
) -> bool:
198-
"""Update job progress (called during indexing)."""
244+
"""
245+
Update job progress during indexing.
246+
247+
Publishes real-time progress event for WebSocket clients,
248+
then updates the job state in Redis.
249+
"""
250+
# Publish progress event for real-time streaming
251+
# This is separate from status events - more granular
252+
percent = int((files_processed / files_total) * 100) if files_total > 0 else 0
253+
254+
self._publish_event(job_id, {
255+
"type": "progress",
256+
"files_processed": files_processed,
257+
"files_total": files_total,
258+
"functions_found": functions_found,
259+
"current_file": current_file,
260+
"percent": percent
261+
})
262+
263+
# Update job state in Redis (for polling fallback)
199264
progress = JobProgress(
200265
files_total=files_total,
201266
files_processed=files_processed,

0 commit comments

Comments
 (0)