diff --git a/backend/main.py b/backend/main.py index 5f18600..fb6db31 100644 --- a/backend/main.py +++ b/backend/main.py @@ -27,6 +27,7 @@ from routes.api_keys import router as api_keys_router from routes.users import router as users_router from routes.search_v2 import router as search_v2_router +from routes.ws_playground import websocket_playground_index # Lifespan context manager for startup/shutdown @@ -91,8 +92,9 @@ async def dispatch(self, request: Request, call_next): app.include_router(users_router, prefix=API_PREFIX) app.include_router(search_v2_router, prefix=API_PREFIX) -# WebSocket endpoint (versioned) +# WebSocket endpoints (versioned) app.add_api_websocket_route(f"{API_PREFIX}/ws/index/{{repo_id}}", websocket_index) +app.add_api_websocket_route(f"{API_PREFIX}/ws/playground/{{job_id}}", websocket_playground_index) # ===== ERROR HANDLERS ===== diff --git a/backend/routes/ws_playground.py b/backend/routes/ws_playground.py new file mode 100644 index 0000000..cc642d4 --- /dev/null +++ b/backend/routes/ws_playground.py @@ -0,0 +1,191 @@ +""" +WebSocket endpoint for real-time playground indexing progress. + +This provides instant updates as files are indexed, giving users +a smooth streaming experience instead of polling every 2 seconds. + +Channel format: job:{job_id}:events +Message types: connected, cloning, progress, completed, error +""" +import json +import asyncio +from typing import Optional + +from fastapi import WebSocket, WebSocketDisconnect + +from dependencies import redis_client +from services.observability import logger +from services.anonymous_indexer import AnonymousIndexingJob + + +# How long between messages before sending a ping +PING_INTERVAL_SECONDS = 30 + +# How long to wait for any activity before closing +IDLE_TIMEOUT_SECONDS = 120 + + +async def websocket_playground_index(websocket: WebSocket, job_id: str): + """ + Stream indexing progress to client via WebSocket. + + Subscribes to Redis pub/sub channel for this job and forwards + all events to the connected client. Closes when job completes + or fails, or if client disconnects. + + No auth required - job_id is an unguessable UUID that acts as + a bearer token. Only the session that created the job knows it. + """ + # Validate job_id format (basic sanity check) + if not job_id or len(job_id) < 10: + # Must accept before we can close with a reason + await websocket.accept() + await websocket.close(code=4400, reason="Invalid job ID") + return + + # Validate we have Redis (required for pub/sub) + if not redis_client: + logger.error("WebSocket failed - no Redis connection") + await websocket.accept() + await websocket.close(code=4500, reason="Service unavailable") + return + + # Check if job exists before subscribing + job_manager = AnonymousIndexingJob(redis_client) + job = job_manager.get_job(job_id) + + if not job: + await websocket.accept() + await websocket.close(code=4404, reason="Job not found") + return + + # Accept the WebSocket connection + await websocket.accept() + logger.info("WebSocket connected", job_id=job_id[:12]) + + # Handle race condition: job might already be complete + job_status = job.get("status") + if job_status == "completed": + await websocket.send_json({ + "type": "completed", + "job_id": job_id, + "repo_id": job.get("repo_id"), + "stats": job.get("stats"), + "message": "Indexing already complete" + }) + await websocket.close() + return + elif job_status == "failed": + await websocket.send_json({ + "type": "error", + "job_id": job_id, + "error": job.get("error"), + "message": job.get("error_message", "Indexing failed"), + "recoverable": False + }) + await websocket.close() + return + + channel = f"job:{job_id}:events" + pubsub = redis_client.pubsub() + + try: + # Subscribe to job's event channel + await asyncio.to_thread(pubsub.subscribe, channel) + logger.debug("Subscribed to channel", channel=channel) + + # Send initial ack with current state + await websocket.send_json({ + "type": "connected", + "job_id": job_id, + "current_status": job_status, + "message": "Listening for indexing events" + }) + + # Listen for messages + last_activity = asyncio.get_event_loop().time() + + while True: + current_time = asyncio.get_event_loop().time() + + # Check for idle timeout + if current_time - last_activity > IDLE_TIMEOUT_SECONDS: + logger.warning("WebSocket idle timeout", job_id=job_id[:12]) + await websocket.send_json({ + "type": "error", + "message": "Connection timed out - no activity" + }) + break + + # Check for new message (non-blocking with short timeout) + message = await asyncio.to_thread( + pubsub.get_message, + ignore_subscribe_messages=True, + timeout=PING_INTERVAL_SECONDS + ) + + if message is None: + # No message - send ping to keep connection alive + try: + await websocket.send_json({"type": "ping"}) + except Exception: + logger.debug("Client disconnected during ping", job_id=job_id[:12]) + break + continue + + if message["type"] != "message": + continue + + # Got a message - reset activity timer + last_activity = current_time + + # Parse and forward the event + try: + event_data = json.loads(message["data"]) + await websocket.send_json(event_data) + + # Close connection after terminal events + event_type = event_data.get("type") + if event_type in ("completed", "error"): + logger.info( + "Job finished, closing WebSocket", + job_id=job_id[:12], + event_type=event_type + ) + break + + except json.JSONDecodeError: + logger.warning("Invalid JSON in pub/sub message", job_id=job_id[:12]) + continue + except Exception as e: + logger.error("Error forwarding message", error=str(e), job_id=job_id[:12]) + continue + + except WebSocketDisconnect: + logger.debug("WebSocket disconnected by client", job_id=job_id[:12]) + + except Exception as e: + logger.error("WebSocket error", error=str(e), job_id=job_id[:12]) + try: + await websocket.send_json({ + "type": "error", + "message": "Internal server error" + }) + except Exception: + pass + + finally: + # Clean up pub/sub subscription + try: + await asyncio.to_thread(pubsub.unsubscribe, channel) + await asyncio.to_thread(pubsub.close) + except Exception: + pass + + # Close WebSocket if still open + try: + await websocket.close() + except Exception: + pass + + logger.debug("WebSocket cleanup complete", job_id=job_id[:12]) diff --git a/backend/scripts/manual_ws_test.py b/backend/scripts/manual_ws_test.py new file mode 100644 index 0000000..07c7874 --- /dev/null +++ b/backend/scripts/manual_ws_test.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +""" +MANUAL WebSocket E2E test for playground indexing. + +NOT run in CI - requires: + - Running backend server (uvicorn main:app) + - Redis running + - aiohttp installed (pip install aiohttp) + +This script: +1. Creates an indexing job via the REST API +2. Connects to the WebSocket endpoint +3. Listens for all events until completion/error +4. Reports what we received + +Usage: + cd backend + pip install aiohttp # if not installed + python3 scripts/manual_ws_test.py +""" +import asyncio +import aiohttp +import json +import sys +from datetime import datetime + +# Config +BASE_URL = "http://localhost:8000/api/v1" +WS_URL = "ws://localhost:8000/api/v1" +TEST_REPO = "https://github.com/pmndrs/zustand" # Small, fast to index + + +def log(msg: str, level: str = "INFO"): + """Print timestamped log message.""" + ts = datetime.now().strftime("%H:%M:%S.%f")[:-3] + icon = {"INFO": "â„šī¸", "OK": "✅", "ERR": "❌", "WS": "🔌", "EVENT": "📨"}.get(level, "â€ĸ") + print(f"[{ts}] {icon} {msg}") + + +async def create_indexing_job(session: aiohttp.ClientSession) -> dict: + """Create a new indexing job via REST API.""" + log("Creating indexing job for zustand...") + + async with session.post( + f"{BASE_URL}/playground/index", + json={"github_url": TEST_REPO} + ) as resp: + # 202 Accepted is the expected status for async job creation + if resp.status not in (200, 202): + text = await resp.text() + log(f"Failed to create job: {resp.status} - {text}", "ERR") + return None + + data = await resp.json() + job_id = data.get("job_id") + log(f"Job created: {job_id} (status: {resp.status})", "OK") + return data + + +async def listen_websocket(job_id: str) -> list: + """Connect to WebSocket and collect all events.""" + events = [] + ws_endpoint = f"{WS_URL}/ws/playground/{job_id}" + + log(f"Connecting to WebSocket: {ws_endpoint}", "WS") + + async with aiohttp.ClientSession() as session: + try: + async with session.ws_connect(ws_endpoint, timeout=120) as ws: + log("WebSocket connected!", "OK") + + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + event = json.loads(msg.data) + events.append(event) + + event_type = event.get("type", "unknown") + + # Log based on event type + if event_type == "connected": + log(f"Server acknowledged connection", "EVENT") + elif event_type == "ping": + log("Received keepalive ping", "EVENT") + elif event_type == "cloning": + repo = event.get("repo_name", "?") + log(f"Cloning: {repo}", "EVENT") + elif event_type == "progress": + pct = event.get("percent", 0) + files = event.get("files_processed", 0) + total = event.get("files_total", 0) + current = event.get("current_file") or "" + funcs = event.get("functions_found", 0) + # Truncate long paths + if current and len(current) > 40: + current = "..." + current[-37:] + log(f"Progress: {pct}% ({files}/{total}) | {funcs} funcs | {current}", "EVENT") + elif event_type == "completed": + stats = event.get("stats", {}) + log(f"COMPLETED! Functions: {stats.get('functions_found', '?')}, Time: {stats.get('time_taken_seconds', '?')}s", "OK") + break + elif event_type == "error": + log(f"ERROR: {event.get('message', 'Unknown error')}", "ERR") + break + else: + log(f"Unknown event: {event_type}", "EVENT") + + elif msg.type == aiohttp.WSMsgType.ERROR: + log(f"WebSocket error: {ws.exception()}", "ERR") + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + log("WebSocket closed by server", "WS") + break + + except asyncio.TimeoutError: + log("WebSocket connection timed out", "ERR") + except Exception as e: + log(f"WebSocket error: {e}", "ERR") + + return events + + +async def main(): + """Run the end-to-end test.""" + print("\n" + "="*60) + print(" WebSocket E2E Test - Playground Indexing") + print("="*60 + "\n") + + async with aiohttp.ClientSession() as session: + # Step 1: Create job + job_data = await create_indexing_job(session) + if not job_data: + sys.exit(1) + + job_id = job_data.get("job_id") + if not job_id: + log("No job_id in response", "ERR") + sys.exit(1) + + # Step 2: Listen to WebSocket + print() + events = await listen_websocket(job_id) + + # Step 3: Summary + print("\n" + "="*60) + print(" Test Summary") + print("="*60) + + event_types = [e.get("type") for e in events] + print(f"\nTotal events received: {len(events)}") + print(f"Event types: {' → '.join(event_types)}") + + # Check expected flow + # Note: "cloning" may be skipped if repo was recently cloned + required = ["connected", "completed"] + has_required = all(t in event_types for t in required) + has_progress = "progress" in event_types + + print() + if has_required and has_progress: + log("TEST PASSED - Full event flow received!", "OK") + print() + return 0 + elif "error" in event_types: + log("TEST COMPLETED WITH ERROR - Error event received (may be expected)", "ERR") + print() + return 1 + else: + log(f"TEST INCOMPLETE - Missing events. Got: {event_types}", "ERR") + print() + return 1 + + +if __name__ == "__main__": + exit_code = asyncio.run(main()) + sys.exit(exit_code) diff --git a/backend/services/anonymous_indexer.py b/backend/services/anonymous_indexer.py index b8048c6..9ffacb8 100644 --- a/backend/services/anonymous_indexer.py +++ b/backend/services/anonymous_indexer.py @@ -60,6 +60,7 @@ class AnonymousIndexingJob: """ REDIS_PREFIX = "anon_job:" + PUBSUB_PREFIX = "job:" # Channel: job:{job_id}:events JOB_TTL_SECONDS = 3600 # 1 hour for job metadata REPO_TTL_HOURS = 24 # 24 hours for indexed data TEMP_DIR = "/tmp/anon_repos" @@ -71,6 +72,27 @@ def __init__(self, redis_client): # Ensure temp directory exists Path(self.TEMP_DIR).mkdir(parents=True, exist_ok=True) + def _get_channel(self, job_id: str) -> str: + """Get Redis pub/sub channel for job events.""" + return f"{self.PUBSUB_PREFIX}{job_id}:events" + + def _publish_event(self, job_id: str, event: dict) -> None: + """ + Publish event to Redis pub/sub for real-time WebSocket updates. + + Events are fire-and-forget - if no one is listening, that's fine. + The job state in Redis is the source of truth for polling fallback. + """ + if not self.redis: + return + + try: + channel = self._get_channel(job_id) + self.redis.publish(channel, json.dumps(event)) + except Exception as e: + # Don't fail the job if pub/sub fails - it's nice-to-have + logger.warning("Failed to publish event", job_id=job_id, error=str(e)) + @staticmethod def generate_job_id() -> str: """Generate unique job ID.""" @@ -160,7 +182,7 @@ def update_status( error: Optional[str] = None, error_message: Optional[str] = None ) -> bool: - """Update job status in Redis.""" + """Update job status in Redis and publish event to WebSocket clients.""" if not self.redis: return False @@ -185,6 +207,30 @@ def update_status( key = self._get_key(job_id) self.redis.setex(key, self.JOB_TTL_SECONDS, json.dumps(job)) + # Publish status change event for WebSocket clients + # Note: PROCESSING events are handled by update_progress() to avoid duplicates + event = {"type": status.value, "job_id": job_id} + + if status == JobStatus.QUEUED: + event["message"] = "Job queued for processing" + elif status == JobStatus.CLONING: + event["message"] = "Cloning repository..." + event["repo_name"] = job.get("repo_name") + elif status == JobStatus.PROCESSING: + # Skip publishing here - update_progress() sends granular progress events + return True + elif status == JobStatus.COMPLETED: + event["message"] = "Indexing complete!" + event["repo_id"] = repo_id + if stats: + event["stats"] = stats.to_dict() + elif status == JobStatus.FAILED: + event["message"] = error_message or "Indexing failed" + event["error"] = error + event["recoverable"] = error not in ("timeout", "clone_failed") + + self._publish_event(job_id, event) + return True def update_progress( @@ -195,7 +241,26 @@ def update_progress( files_total: int, current_file: Optional[str] = None ) -> bool: - """Update job progress (called during indexing).""" + """ + Update job progress during indexing. + + Publishes real-time progress event for WebSocket clients, + then updates the job state in Redis. + """ + # Publish progress event for real-time streaming + # This is separate from status events - more granular + percent = int((files_processed / files_total) * 100) if files_total > 0 else 0 + + self._publish_event(job_id, { + "type": "progress", + "files_processed": files_processed, + "files_total": files_total, + "functions_found": functions_found, + "current_file": current_file, + "percent": percent + }) + + # Update job state in Redis (for polling fallback) progress = JobProgress( files_total=files_total, files_processed=files_processed, diff --git a/backend/tests/test_ws_playground.py b/backend/tests/test_ws_playground.py new file mode 100644 index 0000000..c558981 --- /dev/null +++ b/backend/tests/test_ws_playground.py @@ -0,0 +1,306 @@ +""" +Tests for playground WebSocket endpoint. + +Tests the real-time indexing progress via WebSocket + Redis Pub/Sub. +""" +import pytest +import json +import asyncio +from unittest.mock import patch, MagicMock, AsyncMock +from fastapi.testclient import TestClient +from fastapi.websockets import WebSocket + + +class TestWebSocketPlayground: + """Test suite for playground WebSocket endpoint.""" + + def test_websocket_rejects_invalid_job_id(self): + """Short/invalid job IDs should be rejected.""" + from main import app + + client = TestClient(app) + + # Job ID too short - should accept then close with 4400 + with client.websocket_connect("/api/v1/ws/playground/abc") as ws: + # Connection will be closed immediately + pass + + def test_websocket_rejects_when_no_redis(self): + """Should close gracefully if Redis is unavailable.""" + from routes.ws_playground import websocket_playground_index + + mock_ws = AsyncMock(spec=WebSocket) + + with patch('routes.ws_playground.redis_client', None): + asyncio.run(websocket_playground_index(mock_ws, "idx_abc123def456")) + + # Should accept first, then close + mock_ws.accept.assert_called_once() + mock_ws.close.assert_called_once() + call_args = mock_ws.close.call_args + assert call_args[1]['code'] == 4500 + + def test_websocket_rejects_nonexistent_job(self): + """Should close with 4404 if job doesn't exist.""" + from routes.ws_playground import websocket_playground_index + + mock_ws = AsyncMock(spec=WebSocket) + mock_redis = MagicMock() + mock_redis.get.return_value = None # Job not found + + with patch('routes.ws_playground.redis_client', mock_redis): + asyncio.run(websocket_playground_index(mock_ws, "idx_nonexistent")) + + mock_ws.accept.assert_called_once() + mock_ws.close.assert_called_once() + call_args = mock_ws.close.call_args + assert call_args[1]['code'] == 4404 + + def test_websocket_handles_already_completed_job(self): + """If job is already complete, send completion and close.""" + from routes.ws_playground import websocket_playground_index + + mock_ws = AsyncMock(spec=WebSocket) + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "completed", + "repo_id": "anon_test123", + "stats": {"files_indexed": 100, "functions_found": 500} + }) + + with patch('routes.ws_playground.redis_client', mock_redis): + asyncio.run(websocket_playground_index(mock_ws, "idx_test123")) + + # Should send completed event immediately + mock_ws.send_json.assert_called_once() + sent_data = mock_ws.send_json.call_args[0][0] + assert sent_data["type"] == "completed" + assert sent_data["repo_id"] == "anon_test123" + mock_ws.close.assert_called_once() + + def test_websocket_handles_already_failed_job(self): + """If job already failed, send error and close.""" + from routes.ws_playground import websocket_playground_index + + mock_ws = AsyncMock(spec=WebSocket) + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "failed", + "error": "clone_failed", + "error_message": "Repository not found" + }) + + with patch('routes.ws_playground.redis_client', mock_redis): + asyncio.run(websocket_playground_index(mock_ws, "idx_test123")) + + mock_ws.send_json.assert_called_once() + sent_data = mock_ws.send_json.call_args[0][0] + assert sent_data["type"] == "error" + assert sent_data["error"] == "clone_failed" + mock_ws.close.assert_called_once() + + +class TestPubSubIntegration: + """Test Redis Pub/Sub event publishing.""" + + def test_publish_event_called_on_status_update(self): + """Verify events are published when status changes.""" + from services.anonymous_indexer import AnonymousIndexingJob, JobStatus + + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "queued", + "repo_name": "flask" + }) + + job_manager = AnonymousIndexingJob(mock_redis) + job_manager.update_status("idx_test123", JobStatus.CLONING) + + mock_redis.publish.assert_called() + call_args = mock_redis.publish.call_args + + channel = call_args[0][0] + event_data = json.loads(call_args[0][1]) + + assert channel == "job:idx_test123:events" + assert event_data["type"] == "cloning" + assert event_data["repo_name"] == "flask" + + def test_progress_event_published_with_file_info(self): + """Verify progress events include current file.""" + from services.anonymous_indexer import AnonymousIndexingJob, JobStatus + + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "processing" + }) + + job_manager = AnonymousIndexingJob(mock_redis) + + job_manager.update_progress( + job_id="idx_test123", + files_processed=25, + functions_found=150, + files_total=100, + current_file="src/flask/app.py" + ) + + # Find the progress event (first publish call) + progress_call = mock_redis.publish.call_args_list[0] + event_data = json.loads(progress_call[0][1]) + + assert event_data["type"] == "progress" + assert event_data["files_processed"] == 25 + assert event_data["files_total"] == 100 + assert event_data["current_file"] == "src/flask/app.py" + assert event_data["percent"] == 25 + + def test_completed_event_includes_stats(self): + """Verify completion event includes stats.""" + from services.anonymous_indexer import AnonymousIndexingJob, JobStatus, JobStats + + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "processing" + }) + + job_manager = AnonymousIndexingJob(mock_redis) + + stats = JobStats( + files_indexed=100, + functions_found=500, + time_taken_seconds=45.2 + ) + + job_manager.update_status( + "idx_test123", + JobStatus.COMPLETED, + stats=stats, + repo_id="anon_test123" + ) + + call_args = mock_redis.publish.call_args + event_data = json.loads(call_args[0][1]) + + assert event_data["type"] == "completed" + assert event_data["repo_id"] == "anon_test123" + assert event_data["stats"]["functions_found"] == 500 + + def test_processing_status_skips_duplicate_publish(self): + """PROCESSING status should not publish (handled by update_progress).""" + from services.anonymous_indexer import AnonymousIndexingJob, JobStatus, JobProgress + + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "cloning" + }) + + job_manager = AnonymousIndexingJob(mock_redis) + + progress = JobProgress(files_total=100, files_processed=0, functions_found=0) + job_manager.update_status("idx_test123", JobStatus.PROCESSING, progress=progress) + + # Should NOT have published (returns early for PROCESSING) + mock_redis.publish.assert_not_called() + + def test_failed_event_includes_recoverable_flag(self): + """Failed events should indicate if error is recoverable.""" + from services.anonymous_indexer import AnonymousIndexingJob, JobStatus + + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "processing" + }) + + job_manager = AnonymousIndexingJob(mock_redis) + + # Clone failures are not recoverable + job_manager.update_status( + "idx_test123", + JobStatus.FAILED, + error="clone_failed", + error_message="Repo not found" + ) + + call_args = mock_redis.publish.call_args + event_data = json.loads(call_args[0][1]) + + assert event_data["type"] == "failed" + assert event_data["recoverable"] == False + + +class TestEventFormats: + """Verify event message formats match frontend expectations.""" + + def test_event_types_are_strings(self): + """Event types should be string values, not enum objects.""" + from services.anonymous_indexer import AnonymousIndexingJob, JobStatus + + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "queued" + }) + + job_manager = AnonymousIndexingJob(mock_redis) + job_manager.update_status("idx_test123", JobStatus.CLONING) + + call_args = mock_redis.publish.call_args + event_data = json.loads(call_args[0][1]) + + assert isinstance(event_data["type"], str) + assert event_data["type"] == "cloning" + + def test_progress_percent_is_integer(self): + """Progress percent should be an integer 0-100.""" + from services.anonymous_indexer import AnonymousIndexingJob + + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "processing" + }) + + job_manager = AnonymousIndexingJob(mock_redis) + + job_manager.update_progress( + job_id="idx_test123", + files_processed=33, + functions_found=100, + files_total=100, + current_file="test.py" + ) + + progress_call = mock_redis.publish.call_args_list[0] + event_data = json.loads(progress_call[0][1]) + + assert isinstance(event_data["percent"], int) + assert event_data["percent"] == 33 + + def test_all_events_include_job_id(self): + """All events should include job_id for client correlation.""" + from services.anonymous_indexer import AnonymousIndexingJob, JobStatus + + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps({ + "job_id": "idx_test123", + "status": "queued" + }) + + job_manager = AnonymousIndexingJob(mock_redis) + + # Test different event types + job_manager.update_status("idx_test123", JobStatus.CLONING) + + call_args = mock_redis.publish.call_args + event_data = json.loads(call_args[0][1]) + + assert "job_id" in event_data + assert event_data["job_id"] == "idx_test123"