Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from routes.api_keys import router as api_keys_router
from routes.users import router as users_router
from routes.search_v2 import router as search_v2_router
from routes.ws_playground import websocket_playground_index


# Lifespan context manager for startup/shutdown
Expand Down Expand Up @@ -91,8 +92,9 @@ async def dispatch(self, request: Request, call_next):
app.include_router(users_router, prefix=API_PREFIX)
app.include_router(search_v2_router, prefix=API_PREFIX)

# WebSocket endpoint (versioned)
# WebSocket endpoints (versioned)
app.add_api_websocket_route(f"{API_PREFIX}/ws/index/{{repo_id}}", websocket_index)
app.add_api_websocket_route(f"{API_PREFIX}/ws/playground/{{job_id}}", websocket_playground_index)


# ===== ERROR HANDLERS =====
Expand Down
191 changes: 191 additions & 0 deletions backend/routes/ws_playground.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""
WebSocket endpoint for real-time playground indexing progress.

This provides instant updates as files are indexed, giving users
a smooth streaming experience instead of polling every 2 seconds.

Channel format: job:{job_id}:events
Message types: connected, cloning, progress, completed, error
"""
import json
import asyncio
from typing import Optional

from fastapi import WebSocket, WebSocketDisconnect

from dependencies import redis_client
from services.observability import logger
from services.anonymous_indexer import AnonymousIndexingJob


# How long between messages before sending a ping
PING_INTERVAL_SECONDS = 30

# How long to wait for any activity before closing
IDLE_TIMEOUT_SECONDS = 120


async def websocket_playground_index(websocket: WebSocket, job_id: str):
"""
Stream indexing progress to client via WebSocket.

Subscribes to Redis pub/sub channel for this job and forwards
all events to the connected client. Closes when job completes
or fails, or if client disconnects.

No auth required - job_id is an unguessable UUID that acts as
a bearer token. Only the session that created the job knows it.
"""
# Validate job_id format (basic sanity check)
if not job_id or len(job_id) < 10:
# Must accept before we can close with a reason
await websocket.accept()
await websocket.close(code=4400, reason="Invalid job ID")
return

# Validate we have Redis (required for pub/sub)
if not redis_client:
logger.error("WebSocket failed - no Redis connection")
await websocket.accept()
await websocket.close(code=4500, reason="Service unavailable")
return

# Check if job exists before subscribing
job_manager = AnonymousIndexingJob(redis_client)
job = job_manager.get_job(job_id)

if not job:
await websocket.accept()
await websocket.close(code=4404, reason="Job not found")
return

# Accept the WebSocket connection
await websocket.accept()
logger.info("WebSocket connected", job_id=job_id[:12])

# Handle race condition: job might already be complete
job_status = job.get("status")
if job_status == "completed":
await websocket.send_json({
"type": "completed",
"job_id": job_id,
"repo_id": job.get("repo_id"),
"stats": job.get("stats"),
"message": "Indexing already complete"
})
await websocket.close()
return
elif job_status == "failed":
await websocket.send_json({
"type": "error",
"job_id": job_id,
"error": job.get("error"),
"message": job.get("error_message", "Indexing failed"),
"recoverable": False
})
await websocket.close()
return

channel = f"job:{job_id}:events"
pubsub = redis_client.pubsub()

try:
# Subscribe to job's event channel
await asyncio.to_thread(pubsub.subscribe, channel)
logger.debug("Subscribed to channel", channel=channel)

# Send initial ack with current state
await websocket.send_json({
"type": "connected",
"job_id": job_id,
"current_status": job_status,
"message": "Listening for indexing events"
})

# Listen for messages
last_activity = asyncio.get_event_loop().time()

while True:
current_time = asyncio.get_event_loop().time()

# Check for idle timeout
if current_time - last_activity > IDLE_TIMEOUT_SECONDS:
logger.warning("WebSocket idle timeout", job_id=job_id[:12])
await websocket.send_json({
"type": "error",
"message": "Connection timed out - no activity"
})
break

# Check for new message (non-blocking with short timeout)
message = await asyncio.to_thread(
pubsub.get_message,
ignore_subscribe_messages=True,
timeout=PING_INTERVAL_SECONDS
)

if message is None:
# No message - send ping to keep connection alive
try:
await websocket.send_json({"type": "ping"})
except Exception:
logger.debug("Client disconnected during ping", job_id=job_id[:12])
break
continue

if message["type"] != "message":
continue

# Got a message - reset activity timer
last_activity = current_time

# Parse and forward the event
try:
event_data = json.loads(message["data"])
await websocket.send_json(event_data)

# Close connection after terminal events
event_type = event_data.get("type")
if event_type in ("completed", "error"):
logger.info(
"Job finished, closing WebSocket",
job_id=job_id[:12],
event_type=event_type
)
break

except json.JSONDecodeError:
logger.warning("Invalid JSON in pub/sub message", job_id=job_id[:12])
continue
except Exception as e:
logger.error("Error forwarding message", error=str(e), job_id=job_id[:12])
continue

except WebSocketDisconnect:
logger.debug("WebSocket disconnected by client", job_id=job_id[:12])

except Exception as e:
logger.error("WebSocket error", error=str(e), job_id=job_id[:12])
try:
await websocket.send_json({
"type": "error",
"message": "Internal server error"
})
except Exception:
pass

finally:
# Clean up pub/sub subscription
try:
await asyncio.to_thread(pubsub.unsubscribe, channel)
await asyncio.to_thread(pubsub.close)
except Exception:
pass

# Close WebSocket if still open
try:
await websocket.close()
except Exception:
pass

logger.debug("WebSocket cleanup complete", job_id=job_id[:12])
175 changes: 175 additions & 0 deletions backend/scripts/manual_ws_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#!/usr/bin/env python3
"""
MANUAL WebSocket E2E test for playground indexing.

NOT run in CI - requires:
- Running backend server (uvicorn main:app)
- Redis running
- aiohttp installed (pip install aiohttp)

This script:
1. Creates an indexing job via the REST API
2. Connects to the WebSocket endpoint
3. Listens for all events until completion/error
4. Reports what we received

Usage:
cd backend
pip install aiohttp # if not installed
python3 scripts/manual_ws_test.py
"""
import asyncio
import aiohttp
import json
import sys
from datetime import datetime

# Config
BASE_URL = "http://localhost:8000/api/v1"
WS_URL = "ws://localhost:8000/api/v1"
TEST_REPO = "https://github.com/pmndrs/zustand" # Small, fast to index


def log(msg: str, level: str = "INFO"):
"""Print timestamped log message."""
ts = datetime.now().strftime("%H:%M:%S.%f")[:-3]
icon = {"INFO": "ℹ️", "OK": "✅", "ERR": "❌", "WS": "🔌", "EVENT": "📨"}.get(level, "•")
print(f"[{ts}] {icon} {msg}")


async def create_indexing_job(session: aiohttp.ClientSession) -> dict:
"""Create a new indexing job via REST API."""
log("Creating indexing job for zustand...")

async with session.post(
f"{BASE_URL}/playground/index",
json={"github_url": TEST_REPO}
) as resp:
# 202 Accepted is the expected status for async job creation
if resp.status not in (200, 202):
text = await resp.text()
log(f"Failed to create job: {resp.status} - {text}", "ERR")
return None

data = await resp.json()
job_id = data.get("job_id")
log(f"Job created: {job_id} (status: {resp.status})", "OK")
return data


async def listen_websocket(job_id: str) -> list:
"""Connect to WebSocket and collect all events."""
events = []
ws_endpoint = f"{WS_URL}/ws/playground/{job_id}"

log(f"Connecting to WebSocket: {ws_endpoint}", "WS")

async with aiohttp.ClientSession() as session:
try:
async with session.ws_connect(ws_endpoint, timeout=120) as ws:
log("WebSocket connected!", "OK")

async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
event = json.loads(msg.data)
events.append(event)

event_type = event.get("type", "unknown")

# Log based on event type
if event_type == "connected":
log(f"Server acknowledged connection", "EVENT")
elif event_type == "ping":
log("Received keepalive ping", "EVENT")
elif event_type == "cloning":
repo = event.get("repo_name", "?")
log(f"Cloning: {repo}", "EVENT")
elif event_type == "progress":
pct = event.get("percent", 0)
files = event.get("files_processed", 0)
total = event.get("files_total", 0)
current = event.get("current_file") or ""
funcs = event.get("functions_found", 0)
# Truncate long paths
if current and len(current) > 40:
current = "..." + current[-37:]
log(f"Progress: {pct}% ({files}/{total}) | {funcs} funcs | {current}", "EVENT")
elif event_type == "completed":
stats = event.get("stats", {})
log(f"COMPLETED! Functions: {stats.get('functions_found', '?')}, Time: {stats.get('time_taken_seconds', '?')}s", "OK")
break
elif event_type == "error":
log(f"ERROR: {event.get('message', 'Unknown error')}", "ERR")
break
else:
log(f"Unknown event: {event_type}", "EVENT")

elif msg.type == aiohttp.WSMsgType.ERROR:
log(f"WebSocket error: {ws.exception()}", "ERR")
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
log("WebSocket closed by server", "WS")
break

except asyncio.TimeoutError:
log("WebSocket connection timed out", "ERR")
except Exception as e:
log(f"WebSocket error: {e}", "ERR")

return events


async def main():
"""Run the end-to-end test."""
print("\n" + "="*60)
print(" WebSocket E2E Test - Playground Indexing")
print("="*60 + "\n")

async with aiohttp.ClientSession() as session:
# Step 1: Create job
job_data = await create_indexing_job(session)
if not job_data:
sys.exit(1)

job_id = job_data.get("job_id")
if not job_id:
log("No job_id in response", "ERR")
sys.exit(1)

# Step 2: Listen to WebSocket
print()
events = await listen_websocket(job_id)

# Step 3: Summary
print("\n" + "="*60)
print(" Test Summary")
print("="*60)

event_types = [e.get("type") for e in events]
print(f"\nTotal events received: {len(events)}")
print(f"Event types: {' → '.join(event_types)}")

# Check expected flow
# Note: "cloning" may be skipped if repo was recently cloned
required = ["connected", "completed"]
has_required = all(t in event_types for t in required)
has_progress = "progress" in event_types

print()
if has_required and has_progress:
log("TEST PASSED - Full event flow received!", "OK")
print()
return 0
elif "error" in event_types:
log("TEST COMPLETED WITH ERROR - Error event received (may be expected)", "ERR")
print()
return 1
else:
log(f"TEST INCOMPLETE - Missing events. Got: {event_types}", "ERR")
print()
return 1


if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)
Loading