Skip to content

Commit d0f2033

Browse files
authored
Merge pull request #150 from DevanshuNEU/feature/149-websocket-indexing-progress
feat(backend): WebSocket endpoint for real-time indexing progress (#149)
2 parents a35a035 + 8e4a033 commit d0f2033

5 files changed

Lines changed: 742 additions & 3 deletions

File tree

backend/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from routes.api_keys import router as api_keys_router
2828
from routes.users import router as users_router
2929
from routes.search_v2 import router as search_v2_router
30+
from routes.ws_playground import websocket_playground_index
3031

3132

3233
# Lifespan context manager for startup/shutdown
@@ -91,8 +92,9 @@ async def dispatch(self, request: Request, call_next):
9192
app.include_router(users_router, prefix=API_PREFIX)
9293
app.include_router(search_v2_router, prefix=API_PREFIX)
9394

94-
# WebSocket endpoint (versioned)
95+
# WebSocket endpoints (versioned)
9596
app.add_api_websocket_route(f"{API_PREFIX}/ws/index/{{repo_id}}", websocket_index)
97+
app.add_api_websocket_route(f"{API_PREFIX}/ws/playground/{{job_id}}", websocket_playground_index)
9698

9799

98100
# ===== ERROR HANDLERS =====

backend/routes/ws_playground.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""
2+
WebSocket endpoint for real-time playground indexing progress.
3+
4+
This provides instant updates as files are indexed, giving users
5+
a smooth streaming experience instead of polling every 2 seconds.
6+
7+
Channel format: job:{job_id}:events
8+
Message types: connected, cloning, progress, completed, error
9+
"""
10+
import json
11+
import asyncio
12+
from typing import Optional
13+
14+
from fastapi import WebSocket, WebSocketDisconnect
15+
16+
from dependencies import redis_client
17+
from services.observability import logger
18+
from services.anonymous_indexer import AnonymousIndexingJob
19+
20+
21+
# How long between messages before sending a ping
22+
PING_INTERVAL_SECONDS = 30
23+
24+
# How long to wait for any activity before closing
25+
IDLE_TIMEOUT_SECONDS = 120
26+
27+
28+
async def websocket_playground_index(websocket: WebSocket, job_id: str):
29+
"""
30+
Stream indexing progress to client via WebSocket.
31+
32+
Subscribes to Redis pub/sub channel for this job and forwards
33+
all events to the connected client. Closes when job completes
34+
or fails, or if client disconnects.
35+
36+
No auth required - job_id is an unguessable UUID that acts as
37+
a bearer token. Only the session that created the job knows it.
38+
"""
39+
# Validate job_id format (basic sanity check)
40+
if not job_id or len(job_id) < 10:
41+
# Must accept before we can close with a reason
42+
await websocket.accept()
43+
await websocket.close(code=4400, reason="Invalid job ID")
44+
return
45+
46+
# Validate we have Redis (required for pub/sub)
47+
if not redis_client:
48+
logger.error("WebSocket failed - no Redis connection")
49+
await websocket.accept()
50+
await websocket.close(code=4500, reason="Service unavailable")
51+
return
52+
53+
# Check if job exists before subscribing
54+
job_manager = AnonymousIndexingJob(redis_client)
55+
job = job_manager.get_job(job_id)
56+
57+
if not job:
58+
await websocket.accept()
59+
await websocket.close(code=4404, reason="Job not found")
60+
return
61+
62+
# Accept the WebSocket connection
63+
await websocket.accept()
64+
logger.info("WebSocket connected", job_id=job_id[:12])
65+
66+
# Handle race condition: job might already be complete
67+
job_status = job.get("status")
68+
if job_status == "completed":
69+
await websocket.send_json({
70+
"type": "completed",
71+
"job_id": job_id,
72+
"repo_id": job.get("repo_id"),
73+
"stats": job.get("stats"),
74+
"message": "Indexing already complete"
75+
})
76+
await websocket.close()
77+
return
78+
elif job_status == "failed":
79+
await websocket.send_json({
80+
"type": "error",
81+
"job_id": job_id,
82+
"error": job.get("error"),
83+
"message": job.get("error_message", "Indexing failed"),
84+
"recoverable": False
85+
})
86+
await websocket.close()
87+
return
88+
89+
channel = f"job:{job_id}:events"
90+
pubsub = redis_client.pubsub()
91+
92+
try:
93+
# Subscribe to job's event channel
94+
await asyncio.to_thread(pubsub.subscribe, channel)
95+
logger.debug("Subscribed to channel", channel=channel)
96+
97+
# Send initial ack with current state
98+
await websocket.send_json({
99+
"type": "connected",
100+
"job_id": job_id,
101+
"current_status": job_status,
102+
"message": "Listening for indexing events"
103+
})
104+
105+
# Listen for messages
106+
last_activity = asyncio.get_event_loop().time()
107+
108+
while True:
109+
current_time = asyncio.get_event_loop().time()
110+
111+
# Check for idle timeout
112+
if current_time - last_activity > IDLE_TIMEOUT_SECONDS:
113+
logger.warning("WebSocket idle timeout", job_id=job_id[:12])
114+
await websocket.send_json({
115+
"type": "error",
116+
"message": "Connection timed out - no activity"
117+
})
118+
break
119+
120+
# Check for new message (non-blocking with short timeout)
121+
message = await asyncio.to_thread(
122+
pubsub.get_message,
123+
ignore_subscribe_messages=True,
124+
timeout=PING_INTERVAL_SECONDS
125+
)
126+
127+
if message is None:
128+
# No message - send ping to keep connection alive
129+
try:
130+
await websocket.send_json({"type": "ping"})
131+
except Exception:
132+
logger.debug("Client disconnected during ping", job_id=job_id[:12])
133+
break
134+
continue
135+
136+
if message["type"] != "message":
137+
continue
138+
139+
# Got a message - reset activity timer
140+
last_activity = current_time
141+
142+
# Parse and forward the event
143+
try:
144+
event_data = json.loads(message["data"])
145+
await websocket.send_json(event_data)
146+
147+
# Close connection after terminal events
148+
event_type = event_data.get("type")
149+
if event_type in ("completed", "error"):
150+
logger.info(
151+
"Job finished, closing WebSocket",
152+
job_id=job_id[:12],
153+
event_type=event_type
154+
)
155+
break
156+
157+
except json.JSONDecodeError:
158+
logger.warning("Invalid JSON in pub/sub message", job_id=job_id[:12])
159+
continue
160+
except Exception as e:
161+
logger.error("Error forwarding message", error=str(e), job_id=job_id[:12])
162+
continue
163+
164+
except WebSocketDisconnect:
165+
logger.debug("WebSocket disconnected by client", job_id=job_id[:12])
166+
167+
except Exception as e:
168+
logger.error("WebSocket error", error=str(e), job_id=job_id[:12])
169+
try:
170+
await websocket.send_json({
171+
"type": "error",
172+
"message": "Internal server error"
173+
})
174+
except Exception:
175+
pass
176+
177+
finally:
178+
# Clean up pub/sub subscription
179+
try:
180+
await asyncio.to_thread(pubsub.unsubscribe, channel)
181+
await asyncio.to_thread(pubsub.close)
182+
except Exception:
183+
pass
184+
185+
# Close WebSocket if still open
186+
try:
187+
await websocket.close()
188+
except Exception:
189+
pass
190+
191+
logger.debug("WebSocket cleanup complete", job_id=job_id[:12])

backend/scripts/manual_ws_test.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#!/usr/bin/env python3
2+
"""
3+
MANUAL WebSocket E2E test for playground indexing.
4+
5+
NOT run in CI - requires:
6+
- Running backend server (uvicorn main:app)
7+
- Redis running
8+
- aiohttp installed (pip install aiohttp)
9+
10+
This script:
11+
1. Creates an indexing job via the REST API
12+
2. Connects to the WebSocket endpoint
13+
3. Listens for all events until completion/error
14+
4. Reports what we received
15+
16+
Usage:
17+
cd backend
18+
pip install aiohttp # if not installed
19+
python3 scripts/manual_ws_test.py
20+
"""
21+
import asyncio
22+
import aiohttp
23+
import json
24+
import sys
25+
from datetime import datetime
26+
27+
# Config
28+
BASE_URL = "http://localhost:8000/api/v1"
29+
WS_URL = "ws://localhost:8000/api/v1"
30+
TEST_REPO = "https://github.com/pmndrs/zustand" # Small, fast to index
31+
32+
33+
def log(msg: str, level: str = "INFO"):
34+
"""Print timestamped log message."""
35+
ts = datetime.now().strftime("%H:%M:%S.%f")[:-3]
36+
icon = {"INFO": "ℹ️", "OK": "✅", "ERR": "❌", "WS": "🔌", "EVENT": "📨"}.get(level, "•")
37+
print(f"[{ts}] {icon} {msg}")
38+
39+
40+
async def create_indexing_job(session: aiohttp.ClientSession) -> dict:
41+
"""Create a new indexing job via REST API."""
42+
log("Creating indexing job for zustand...")
43+
44+
async with session.post(
45+
f"{BASE_URL}/playground/index",
46+
json={"github_url": TEST_REPO}
47+
) as resp:
48+
# 202 Accepted is the expected status for async job creation
49+
if resp.status not in (200, 202):
50+
text = await resp.text()
51+
log(f"Failed to create job: {resp.status} - {text}", "ERR")
52+
return None
53+
54+
data = await resp.json()
55+
job_id = data.get("job_id")
56+
log(f"Job created: {job_id} (status: {resp.status})", "OK")
57+
return data
58+
59+
60+
async def listen_websocket(job_id: str) -> list:
61+
"""Connect to WebSocket and collect all events."""
62+
events = []
63+
ws_endpoint = f"{WS_URL}/ws/playground/{job_id}"
64+
65+
log(f"Connecting to WebSocket: {ws_endpoint}", "WS")
66+
67+
async with aiohttp.ClientSession() as session:
68+
try:
69+
async with session.ws_connect(ws_endpoint, timeout=120) as ws:
70+
log("WebSocket connected!", "OK")
71+
72+
async for msg in ws:
73+
if msg.type == aiohttp.WSMsgType.TEXT:
74+
event = json.loads(msg.data)
75+
events.append(event)
76+
77+
event_type = event.get("type", "unknown")
78+
79+
# Log based on event type
80+
if event_type == "connected":
81+
log(f"Server acknowledged connection", "EVENT")
82+
elif event_type == "ping":
83+
log("Received keepalive ping", "EVENT")
84+
elif event_type == "cloning":
85+
repo = event.get("repo_name", "?")
86+
log(f"Cloning: {repo}", "EVENT")
87+
elif event_type == "progress":
88+
pct = event.get("percent", 0)
89+
files = event.get("files_processed", 0)
90+
total = event.get("files_total", 0)
91+
current = event.get("current_file") or ""
92+
funcs = event.get("functions_found", 0)
93+
# Truncate long paths
94+
if current and len(current) > 40:
95+
current = "..." + current[-37:]
96+
log(f"Progress: {pct}% ({files}/{total}) | {funcs} funcs | {current}", "EVENT")
97+
elif event_type == "completed":
98+
stats = event.get("stats", {})
99+
log(f"COMPLETED! Functions: {stats.get('functions_found', '?')}, Time: {stats.get('time_taken_seconds', '?')}s", "OK")
100+
break
101+
elif event_type == "error":
102+
log(f"ERROR: {event.get('message', 'Unknown error')}", "ERR")
103+
break
104+
else:
105+
log(f"Unknown event: {event_type}", "EVENT")
106+
107+
elif msg.type == aiohttp.WSMsgType.ERROR:
108+
log(f"WebSocket error: {ws.exception()}", "ERR")
109+
break
110+
elif msg.type == aiohttp.WSMsgType.CLOSED:
111+
log("WebSocket closed by server", "WS")
112+
break
113+
114+
except asyncio.TimeoutError:
115+
log("WebSocket connection timed out", "ERR")
116+
except Exception as e:
117+
log(f"WebSocket error: {e}", "ERR")
118+
119+
return events
120+
121+
122+
async def main():
123+
"""Run the end-to-end test."""
124+
print("\n" + "="*60)
125+
print(" WebSocket E2E Test - Playground Indexing")
126+
print("="*60 + "\n")
127+
128+
async with aiohttp.ClientSession() as session:
129+
# Step 1: Create job
130+
job_data = await create_indexing_job(session)
131+
if not job_data:
132+
sys.exit(1)
133+
134+
job_id = job_data.get("job_id")
135+
if not job_id:
136+
log("No job_id in response", "ERR")
137+
sys.exit(1)
138+
139+
# Step 2: Listen to WebSocket
140+
print()
141+
events = await listen_websocket(job_id)
142+
143+
# Step 3: Summary
144+
print("\n" + "="*60)
145+
print(" Test Summary")
146+
print("="*60)
147+
148+
event_types = [e.get("type") for e in events]
149+
print(f"\nTotal events received: {len(events)}")
150+
print(f"Event types: {' → '.join(event_types)}")
151+
152+
# Check expected flow
153+
# Note: "cloning" may be skipped if repo was recently cloned
154+
required = ["connected", "completed"]
155+
has_required = all(t in event_types for t in required)
156+
has_progress = "progress" in event_types
157+
158+
print()
159+
if has_required and has_progress:
160+
log("TEST PASSED - Full event flow received!", "OK")
161+
print()
162+
return 0
163+
elif "error" in event_types:
164+
log("TEST COMPLETED WITH ERROR - Error event received (may be expected)", "ERR")
165+
print()
166+
return 1
167+
else:
168+
log(f"TEST INCOMPLETE - Missing events. Got: {event_types}", "ERR")
169+
print()
170+
return 1
171+
172+
173+
if __name__ == "__main__":
174+
exit_code = asyncio.run(main())
175+
sys.exit(exit_code)

0 commit comments

Comments
 (0)