diff --git a/backend/routes/playground.py b/backend/routes/playground.py index 2080f57..8ed7f40 100644 --- a/backend/routes/playground.py +++ b/backend/routes/playground.py @@ -36,7 +36,7 @@ class PlaygroundSearchRequest(BaseModel): async def load_demo_repos(): """Load pre-indexed demo repos. Called from main.py on startup.""" - global DEMO_REPO_IDS + # Note: We mutate DEMO_REPO_IDS dict, no need for 'global' statement try: repos = repo_manager.list_repos() for repo in repos: @@ -89,15 +89,15 @@ def _get_limiter() -> PlaygroundLimiter: async def get_playground_limits(req: Request): """ Get current rate limit status for this user. - + Frontend should call this on page load to show accurate remaining count. """ session_token = _get_session_token(req) client_ip = _get_client_ip(req) - + limiter = _get_limiter() result = limiter.check_limit(session_token, client_ip) - + return { "remaining": result.remaining, "limit": result.limit, @@ -106,24 +106,92 @@ async def get_playground_limits(req: Request): } +@router.get("/session") +async def get_session_info(req: Request, response: Response): + """ + Get current session state including indexed repo info. + + Returns complete session data for frontend state management. + Creates a new session if none exists. + + Response schema (see issue #127): + { + "session_id": "pg_abc123...", + "created_at": "2025-12-24T10:00:00Z", + "expires_at": "2025-12-25T10:00:00Z", + "indexed_repo": { + "repo_id": "repo_abc123", + "github_url": "https://github.com/user/repo", + "name": "repo", + "indexed_at": "2025-12-24T10:05:00Z", + "expires_at": "2025-12-25T10:05:00Z", + "file_count": 198 + }, + "searches": { + "used": 12, + "limit": 50, + "remaining": 38 + } + } + """ + session_token = _get_session_token(req) + limiter = _get_limiter() + + # Check if Redis is available + if not redis_client: + logger.error("Redis unavailable for session endpoint") + raise HTTPException( + status_code=503, + detail={ + "message": "Service temporarily unavailable", + "retry_after": 30, + } + ) + + # Get existing session data + session_data = limiter.get_session_data(session_token) + + # If no session exists, create one + if session_data.session_id is None: + new_token = limiter._generate_session_token() + + if limiter.create_session(new_token): + _set_session_cookie(response, new_token) + session_data = limiter.get_session_data(new_token) + logger.info("Created new session via /session endpoint", + session_token=new_token[:8]) + else: + # Failed to create session (Redis issue) + raise HTTPException( + status_code=503, + detail={ + "message": "Failed to create session", + "retry_after": 30, + } + ) + + # Return formatted response + return session_data.to_response(limit=limiter.SESSION_LIMIT_PER_DAY) + + @router.post("/search") async def playground_search( - request: PlaygroundSearchRequest, + request: PlaygroundSearchRequest, req: Request, response: Response ): """ Public playground search - rate limited by session/IP. - + Sets httpOnly cookie on first request to track device. """ session_token = _get_session_token(req) client_ip = _get_client_ip(req) - + # Rate limit check AND record limiter = _get_limiter() limit_result = limiter.check_and_record(session_token, client_ip) - + if not limit_result.allowed: raise HTTPException( status_code=429, @@ -134,16 +202,16 @@ async def playground_search( "resets_at": limit_result.resets_at.isoformat(), } ) - + # Set session cookie if new token was created if limit_result.session_token: _set_session_cookie(response, limit_result.session_token) - + # Validate query valid_query, query_error = InputValidator.validate_search_query(request.query) if not valid_query: raise HTTPException(status_code=400, detail=f"Invalid query: {query_error}") - + # Get demo repo ID repo_id = DEMO_REPO_IDS.get(request.demo_repo) if not repo_id: @@ -156,12 +224,12 @@ async def playground_search( status_code=404, detail=f"Demo repo '{request.demo_repo}' not available" ) - + start_time = time.time() - + try: sanitized_query = InputValidator.sanitize_string(request.query, max_length=200) - + # Check cache cached_results = cache.get_search_results(sanitized_query, repo_id) if cached_results: @@ -172,7 +240,7 @@ async def playground_search( "remaining_searches": limit_result.remaining, "limit": limit_result.limit, } - + # Search results = await indexer.semantic_search( query=sanitized_query, @@ -181,12 +249,12 @@ async def playground_search( use_query_expansion=True, use_reranking=True ) - + # Cache results cache.set_search_results(sanitized_query, repo_id, results, ttl=3600) - + search_time = int((time.time() - start_time) * 1000) - + return { "results": results, "count": len(results), @@ -207,9 +275,24 @@ async def list_playground_repos(): """List available demo repositories.""" return { "repos": [ - {"id": "flask", "name": "Flask", "description": "Python web framework", "available": "flask" in DEMO_REPO_IDS}, - {"id": "fastapi", "name": "FastAPI", "description": "Modern Python API", "available": "fastapi" in DEMO_REPO_IDS}, - {"id": "express", "name": "Express", "description": "Node.js framework", "available": "express" in DEMO_REPO_IDS}, + { + "id": "flask", + "name": "Flask", + "description": "Python web framework", + "available": "flask" in DEMO_REPO_IDS + }, + { + "id": "fastapi", + "name": "FastAPI", + "description": "Modern Python API", + "available": "fastapi" in DEMO_REPO_IDS + }, + { + "id": "express", + "name": "Express", + "description": "Node.js framework", + "available": "express" in DEMO_REPO_IDS + }, ] } diff --git a/backend/services/playground_limiter.py b/backend/services/playground_limiter.py index 59847a4..91f595d 100644 --- a/backend/services/playground_limiter.py +++ b/backend/services/playground_limiter.py @@ -1,34 +1,44 @@ """ -Playground Rate Limiter -Redis-backed rate limiting for anonymous playground searches. +Playground Rate Limiter & Session Manager +Redis-backed rate limiting and session management for anonymous playground. Design: - Layer 1: Session token (httpOnly cookie) - 50 searches/day per device - Layer 2: IP-based fallback - 100 searches/day (for shared IPs) - Layer 3: Global circuit breaker - 10,000 searches/hour (cost protection) -Part of #93 implementation. +Session Data (Redis Hash): +- searches_used: Number of searches performed +- created_at: Session creation timestamp +- indexed_repo: JSON blob with indexed repo details (optional) + +Part of #93 (rate limiting) and #127 (session management) implementation. """ +import json import secrets import hashlib -from datetime import datetime, timezone -from typing import Optional, Tuple +from datetime import datetime, timezone, timedelta +from typing import Optional, Tuple, Dict, Any from dataclasses import dataclass -from services.observability import logger +from services.observability import logger, metrics, track_time from services.sentry import capture_exception +# ============================================================================= +# DATA CLASSES +# ============================================================================= + @dataclass class PlaygroundLimitResult: - """Result of a rate limit check""" + """Result of a rate limit check.""" allowed: bool remaining: int limit: int resets_at: datetime - reason: Optional[str] = None # Why blocked (if not allowed) - session_token: Optional[str] = None # New token if created - + reason: Optional[str] = None + session_token: Optional[str] = None + def to_dict(self) -> dict: return { "allowed": self.allowed, @@ -39,96 +49,451 @@ def to_dict(self) -> dict: } +@dataclass +class IndexedRepoData: + """ + Data about an indexed repository in a session. + + Stored as JSON in Redis hash field 'indexed_repo'. + Used by #125 (indexing) and #128 (search) endpoints. + """ + repo_id: str + github_url: str + name: str + file_count: int + indexed_at: str # ISO format + expires_at: str # ISO format + + def to_dict(self) -> dict: + return { + "repo_id": self.repo_id, + "github_url": self.github_url, + "name": self.name, + "file_count": self.file_count, + "indexed_at": self.indexed_at, + "expires_at": self.expires_at, + } + + @classmethod + def from_dict(cls, data: dict) -> "IndexedRepoData": + """Create from dictionary (parsed from Redis JSON).""" + return cls( + repo_id=data.get("repo_id", ""), + github_url=data.get("github_url", ""), + name=data.get("name", ""), + file_count=data.get("file_count", 0), + indexed_at=data.get("indexed_at", ""), + expires_at=data.get("expires_at", ""), + ) + + def is_expired(self) -> bool: + """Check if the indexed repo has expired.""" + try: + expires = datetime.fromisoformat(self.expires_at.replace("Z", "+00:00")) + return datetime.now(timezone.utc) > expires + except (ValueError, AttributeError): + return True # Treat parse errors as expired + + +@dataclass +class SessionData: + """ + Complete session state for the /session endpoint. + + Returned by get_session_data() method. + """ + session_id: Optional[str] = None + searches_used: int = 0 + created_at: Optional[datetime] = None + expires_at: Optional[datetime] = None + indexed_repo: Optional[Dict[str, Any]] = None + + def to_response(self, limit: int) -> dict: + """ + Convert to API response format. + + Matches the schema defined in issue #127. + """ + return { + "session_id": self._truncate_id(self.session_id) if self.session_id else None, + "created_at": self.created_at.isoformat() if self.created_at else None, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "indexed_repo": self.indexed_repo, + "searches": { + "used": self.searches_used, + "limit": limit, + "remaining": max(0, limit - self.searches_used), + }, + } + + @staticmethod + def _truncate_id(session_id: str) -> str: + """Truncate session ID for display (security: don't expose full token).""" + if len(session_id) > 12: + return f"{session_id[:8]}..." + return session_id + + +# ============================================================================= +# MAIN CLASS +# ============================================================================= + class PlaygroundLimiter: """ - Redis-backed rate limiter for playground searches. - + Redis-backed rate limiter and session manager for playground. + + Provides: + - Rate limiting (searches per session/IP) + - Session data management (indexed repos, search counts) + - Global circuit breaker (cost protection) + Usage: limiter = PlaygroundLimiter(redis_client) - - # Check before search + + # Rate limiting result = limiter.check_and_record(session_token, client_ip) if not result.allowed: raise HTTPException(429, result.reason) - - # Set cookie if new session - if result.session_token: - response.set_cookie("pg_session", result.session_token, ...) + + # Session management (#127) + session_data = limiter.get_session_data(session_token) + limiter.set_indexed_repo(session_token, repo_data) + has_repo = limiter.has_indexed_repo(session_token) """ - - # Limits + + # ------------------------------------------------------------------------- + # Configuration + # ------------------------------------------------------------------------- + + # Rate limits SESSION_LIMIT_PER_DAY = 50 # Per device (generous for conversion) IP_LIMIT_PER_DAY = 100 # Per IP (higher for shared networks) GLOBAL_LIMIT_PER_HOUR = 10000 # Circuit breaker (cost protection) - + + # Anonymous indexing limits (#114) + ANON_MAX_FILES = 200 # Max files for anonymous indexing + ANON_REPOS_PER_SESSION = 1 # Max repos per anonymous session + # Redis key prefixes KEY_SESSION = "playground:session:" KEY_IP = "playground:ip:" KEY_GLOBAL = "playground:global:hourly" - + + # Redis hash fields (for session data) + FIELD_SEARCHES = "searches_used" + FIELD_CREATED = "created_at" + FIELD_INDEXED_REPO = "indexed_repo" + # TTLs TTL_DAY = 86400 # 24 hours TTL_HOUR = 3600 # 1 hour - + def __init__(self, redis_client=None): + """ + Initialize the limiter. + + Args: + redis_client: Redis client instance. If None, limiter fails open + (allows all requests - useful for development). + """ self.redis = redis_client - - def _get_midnight_utc(self) -> datetime: - """Get next midnight UTC for reset time""" - now = datetime.now(timezone.utc) - tomorrow = now.replace(hour=0, minute=0, second=0, microsecond=0) - if tomorrow <= now: - from datetime import timedelta - tomorrow += timedelta(days=1) - return tomorrow - - def _hash_ip(self, ip: str) -> str: - """Hash IP for privacy""" - return hashlib.sha256(ip.encode()).hexdigest()[:16] - - def _generate_session_token(self) -> str: - """Generate secure session token""" - return secrets.token_urlsafe(32) - + + # ------------------------------------------------------------------------- + # Session Data Methods (#127) + # ------------------------------------------------------------------------- + + def get_session_data(self, session_token: Optional[str]) -> SessionData: + """ + Get complete session data for display. + + Used by GET /playground/session endpoint. + + Args: + session_token: The session token from cookie (can be None) + + Returns: + SessionData with all session information + + Note: + Returns empty SessionData if token is None or session doesn't exist. + Does NOT create a new session - that's done by check_and_record(). + """ + if not session_token: + logger.debug("get_session_data called with no token") + return SessionData() + + if not self.redis: + logger.warning("Redis unavailable in get_session_data") + return SessionData(session_id=session_token) + + try: + with track_time("session_data_get"): + session_key = f"{self.KEY_SESSION}{session_token}" + + # Ensure we're reading from hash format (handles legacy migration) + self._ensure_hash_format(session_token) + + # Get all session fields + raw_data = self.redis.hgetall(session_key) + + if not raw_data: + logger.debug("Session not found", session_token=session_token[:8]) + return SessionData() + + # Parse the data (handle bytes from Redis) + data = self._decode_hash_data(raw_data) + + # Get TTL for expires_at calculation + ttl = self.redis.ttl(session_key) + expires_at = None + if ttl and ttl > 0: + expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl) + + # Parse created_at + created_at = None + if data.get(self.FIELD_CREATED): + try: + created_str = data[self.FIELD_CREATED].replace("Z", "+00:00") + created_at = datetime.fromisoformat(created_str) + except (ValueError, AttributeError): + pass + + # Parse indexed_repo JSON + indexed_repo = None + if data.get(self.FIELD_INDEXED_REPO): + try: + indexed_repo = json.loads(data[self.FIELD_INDEXED_REPO]) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to parse indexed_repo JSON", + session_token=session_token[:8]) + + # Build response + session_data = SessionData( + session_id=session_token, + searches_used=int(data.get(self.FIELD_SEARCHES, 0)), + created_at=created_at, + expires_at=expires_at, + indexed_repo=indexed_repo, + ) + + metrics.increment("session_data_retrieved") + logger.debug("Session data retrieved", + session_token=session_token[:8], + searches_used=session_data.searches_used, + has_repo=indexed_repo is not None) + + return session_data + + except Exception as e: + logger.error("Failed to get session data", + error=str(e), + session_token=session_token[:8] if session_token else None) + capture_exception(e, operation="get_session_data") + return SessionData(session_id=session_token) + + def set_indexed_repo(self, session_token: str, repo_data: dict) -> bool: + """ + Store indexed repository info in session. + + Called by POST /playground/index endpoint (#125) after successful indexing. + + Args: + session_token: The session token + repo_data: Dictionary with repo info (repo_id, github_url, name, etc.) + + Returns: + True if successful, False otherwise + + Note: + - Overwrites any existing indexed_repo + - Does not affect searches_used count + - repo_data should include: repo_id, github_url, name, file_count, + indexed_at, expires_at + """ + if not session_token: + logger.warning("set_indexed_repo called with no token") + return False + + if not self.redis: + logger.warning("Redis unavailable in set_indexed_repo") + return False + + try: + with track_time("session_repo_set"): + session_key = f"{self.KEY_SESSION}{session_token}" + + # Ensure hash format exists + self._ensure_hash_format(session_token) + + # Serialize repo data to JSON + repo_json = json.dumps(repo_data) + + # Store in hash (preserves other fields like searches_used) + self.redis.hset(session_key, self.FIELD_INDEXED_REPO, repo_json) + + metrics.increment("session_repo_indexed") + logger.info("Indexed repo stored in session", + session_token=session_token[:8], + repo_id=repo_data.get("repo_id"), + repo_name=repo_data.get("name")) + + return True + + except Exception as e: + logger.error("Failed to set indexed repo", + error=str(e), + session_token=session_token[:8]) + capture_exception(e, operation="set_indexed_repo") + return False + + def has_indexed_repo(self, session_token: str) -> bool: + """ + Check if session already has an indexed repository. + + Used by POST /playground/index endpoint (#125) to enforce + 1 repo per session limit. + + Args: + session_token: The session token + + Returns: + True if session has an indexed repo, False otherwise + """ + if not session_token or not self.redis: + return False + + try: + session_key = f"{self.KEY_SESSION}{session_token}" + + # Check if indexed_repo field exists in hash + exists = self.redis.hexists(session_key, self.FIELD_INDEXED_REPO) + + logger.debug("Checked for indexed repo", + session_token=session_token[:8], + has_repo=exists) + + return bool(exists) + + except Exception as e: + logger.error("Failed to check indexed repo", + error=str(e), + session_token=session_token[:8]) + capture_exception(e, operation="has_indexed_repo") + return False + + def clear_indexed_repo(self, session_token: str) -> bool: + """ + Remove indexed repository from session. + + Useful for cleanup or allowing user to index a different repo. + + Args: + session_token: The session token + + Returns: + True if successful, False otherwise + """ + if not session_token or not self.redis: + return False + + try: + session_key = f"{self.KEY_SESSION}{session_token}" + self.redis.hdel(session_key, self.FIELD_INDEXED_REPO) + + logger.info("Cleared indexed repo from session", + session_token=session_token[:8]) + metrics.increment("session_repo_cleared") + + return True + + except Exception as e: + logger.error("Failed to clear indexed repo", + error=str(e), + session_token=session_token[:8]) + capture_exception(e, operation="clear_indexed_repo") + return False + + def create_session(self, session_token: str) -> bool: + """ + Create a new session with initial data. + + Args: + session_token: The session token to create + + Returns: + True if successful, False otherwise + """ + if not session_token or not self.redis: + return False + + try: + session_key = f"{self.KEY_SESSION}{session_token}" + now = datetime.now(timezone.utc).isoformat() + + # Create hash with initial values + self.redis.hset(session_key, mapping={ + self.FIELD_SEARCHES: "0", + self.FIELD_CREATED: now, + }) + self.redis.expire(session_key, self.TTL_DAY) + + logger.info("Created new session", session_token=session_token[:8]) + metrics.increment("session_created") + + return True + + except Exception as e: + logger.error("Failed to create session", + error=str(e), + session_token=session_token[:8]) + capture_exception(e, operation="create_session") + return False + + # ------------------------------------------------------------------------- + # Rate Limiting Methods (existing, updated for hash storage) + # ------------------------------------------------------------------------- + def check_limit( - self, - session_token: Optional[str], + self, + session_token: Optional[str], client_ip: str ) -> PlaygroundLimitResult: """ Check rate limit without recording a search. + Use this for GET /playground/limits endpoint. """ return self._check_limits(session_token, client_ip, record=False) - + def check_and_record( - self, - session_token: Optional[str], + self, + session_token: Optional[str], client_ip: str ) -> PlaygroundLimitResult: """ Check rate limit AND record a search if allowed. + Use this for POST /playground/search endpoint. """ return self._check_limits(session_token, client_ip, record=True) - + def _check_limits( - self, - session_token: Optional[str], + self, + session_token: Optional[str], client_ip: str, record: bool = False ) -> PlaygroundLimitResult: """ Internal method to check all rate limit layers. - + Order of checks: 1. Global circuit breaker (protects cost) 2. Session-based limit (primary) - 3. IP-based limit (fallback) + 3. IP-based limit (fallback for new sessions) """ resets_at = self._get_midnight_utc() new_session_token = None - + # If no Redis, fail OPEN (allow all) if not self.redis: logger.warning("Redis not available, allowing playground search") @@ -138,12 +503,13 @@ def _check_limits( limit=self.SESSION_LIMIT_PER_DAY, resets_at=resets_at, ) - + try: # Layer 1: Global circuit breaker global_allowed, global_count = self._check_global_limit(record) if not global_allowed: logger.warning("Global circuit breaker triggered", count=global_count) + metrics.increment("rate_limit_global_blocked") return PlaygroundLimitResult( allowed=False, remaining=0, @@ -151,13 +517,15 @@ def _check_limits( resets_at=resets_at, reason="Service is experiencing high demand. Please try again later.", ) - + # Layer 2: Session-based limit (primary) if session_token: session_allowed, session_remaining = self._check_session_limit( session_token, record ) if session_allowed: + if record: + metrics.increment("search_recorded") return PlaygroundLimitResult( allowed=True, remaining=session_remaining, @@ -166,6 +534,7 @@ def _check_limits( ) else: # Session exhausted + metrics.increment("rate_limit_session_blocked") return PlaygroundLimitResult( allowed=False, remaining=0, @@ -173,14 +542,15 @@ def _check_limits( resets_at=resets_at, reason="Daily limit reached. Sign up for unlimited searches!", ) - + # No session token - create new one and check IP new_session_token = self._generate_session_token() - + # Layer 3: IP-based limit (for new sessions / fallback) ip_allowed, ip_remaining = self._check_ip_limit(client_ip, record) if not ip_allowed: # IP exhausted (likely abuse or shared network) + metrics.increment("rate_limit_ip_blocked") return PlaygroundLimitResult( allowed=False, remaining=0, @@ -188,13 +558,12 @@ def _check_limits( resets_at=resets_at, reason="Daily limit reached. Sign up for unlimited searches!", ) - - # New session allowed + + # New session allowed - initialize it if record: - # Initialize session counter - session_key = f"{self.KEY_SESSION}{new_session_token}" - self.redis.set(session_key, "1", ex=self.TTL_DAY) - + self._init_new_session(new_session_token) + metrics.increment("search_recorded") + return PlaygroundLimitResult( allowed=True, remaining=self.SESSION_LIMIT_PER_DAY - 1 if record else self.SESSION_LIMIT_PER_DAY, @@ -202,7 +571,7 @@ def _check_limits( resets_at=resets_at, session_token=new_session_token, ) - + except Exception as e: logger.error("Playground rate limit check failed", error=str(e)) capture_exception(e) @@ -213,9 +582,9 @@ def _check_limits( limit=self.SESSION_LIMIT_PER_DAY, resets_at=resets_at, ) - + def _check_global_limit(self, record: bool) -> Tuple[bool, int]: - """Check global circuit breaker""" + """Check global circuit breaker.""" try: if record: count = self.redis.incr(self.KEY_GLOBAL) @@ -223,61 +592,191 @@ def _check_global_limit(self, record: bool) -> Tuple[bool, int]: self.redis.expire(self.KEY_GLOBAL, self.TTL_HOUR) else: count = int(self.redis.get(self.KEY_GLOBAL) or 0) - + allowed = count <= self.GLOBAL_LIMIT_PER_HOUR return allowed, count except Exception as e: logger.error("Global limit check failed", error=str(e)) return True, 0 # Fail open - + def _check_session_limit( - self, - session_token: str, + self, + session_token: str, record: bool ) -> Tuple[bool, int]: - """Check session-based limit""" + """ + Check session-based limit using Redis Hash. + + Updated for #127 to use hash storage instead of simple strings. + """ try: session_key = f"{self.KEY_SESSION}{session_token}" - + + # Ensure hash format (handles legacy string migration) + self._ensure_hash_format(session_token) + if record: - count = self.redis.incr(session_key) + # Atomically increment searches_used field + count = self.redis.hincrby(session_key, self.FIELD_SEARCHES, 1) + + # Set TTL on first search (if not already set) if count == 1: + now = datetime.now(timezone.utc).isoformat() + self.redis.hset(session_key, self.FIELD_CREATED, now) self.redis.expire(session_key, self.TTL_DAY) else: - count = int(self.redis.get(session_key) or 0) - + # Just read current count + count_str = self.redis.hget(session_key, self.FIELD_SEARCHES) + count = int(count_str) if count_str else 0 + remaining = max(0, self.SESSION_LIMIT_PER_DAY - count) allowed = count <= self.SESSION_LIMIT_PER_DAY return allowed, remaining + except Exception as e: logger.error("Session limit check failed", error=str(e)) return True, self.SESSION_LIMIT_PER_DAY # Fail open - + def _check_ip_limit(self, client_ip: str, record: bool) -> Tuple[bool, int]: - """Check IP-based limit""" + """Check IP-based limit.""" try: ip_hash = self._hash_ip(client_ip) ip_key = f"{self.KEY_IP}{ip_hash}" - + if record: count = self.redis.incr(ip_key) if count == 1: self.redis.expire(ip_key, self.TTL_DAY) else: count = int(self.redis.get(ip_key) or 0) - + remaining = max(0, self.IP_LIMIT_PER_DAY - count) allowed = count <= self.IP_LIMIT_PER_DAY return allowed, remaining except Exception as e: logger.error("IP limit check failed", error=str(e)) return True, self.IP_LIMIT_PER_DAY # Fail open - + + # ------------------------------------------------------------------------- + # Helper Methods + # ------------------------------------------------------------------------- + + def _get_midnight_utc(self) -> datetime: + """Get next midnight UTC for reset time.""" + now = datetime.now(timezone.utc) + tomorrow = now.replace(hour=0, minute=0, second=0, microsecond=0) + if tomorrow <= now: + tomorrow += timedelta(days=1) + return tomorrow + + def _hash_ip(self, ip: str) -> str: + """Hash IP for privacy.""" + return hashlib.sha256(ip.encode()).hexdigest()[:16] + + def _generate_session_token(self) -> str: + """Generate secure session token.""" + return secrets.token_urlsafe(32) + + def _ensure_hash_format(self, session_token: str) -> None: + """ + Ensure session data is in hash format. + + Handles migration from legacy string format (just a counter) + to new hash format (searches_used + created_at + indexed_repo). + + This is called before any hash operations to maintain + backward compatibility with existing sessions. + """ + session_key = f"{self.KEY_SESSION}{session_token}" + + try: + key_type = self.redis.type(session_key) + + # Handle bytes response from some Redis clients + if isinstance(key_type, bytes): + key_type = key_type.decode('utf-8') + + if key_type == 'string': + # Legacy format - migrate to hash + logger.info("Migrating legacy session to hash format", + session_token=session_token[:8]) + + # Read old count + old_count = self.redis.get(session_key) + count = int(old_count) if old_count else 0 + + # Get TTL before delete + ttl = self.redis.ttl(session_key) + + # Delete old string key + self.redis.delete(session_key) + + # Create new hash with migrated data + now = datetime.now(timezone.utc).isoformat() + self.redis.hset(session_key, mapping={ + self.FIELD_SEARCHES: str(count), + self.FIELD_CREATED: now, + }) + + # Restore TTL + if ttl and ttl > 0: + self.redis.expire(session_key, ttl) + + metrics.increment("session_migrated") + logger.info("Session migrated successfully", + session_token=session_token[:8], + searches_migrated=count) + + except Exception as e: + # Don't fail the operation, just log + logger.warning("Session format check failed", + error=str(e), + session_token=session_token[:8]) + + def _init_new_session(self, session_token: str) -> None: + """ + Initialize a new session with hash structure. + + Called when creating a new session on first search. + """ + session_key = f"{self.KEY_SESSION}{session_token}" + now = datetime.now(timezone.utc).isoformat() + + self.redis.hset(session_key, mapping={ + self.FIELD_SEARCHES: "1", # First search + self.FIELD_CREATED: now, + }) + self.redis.expire(session_key, self.TTL_DAY) + + metrics.increment("session_created") + logger.debug("New session initialized", session_token=session_token[:8]) + + def _decode_hash_data(self, raw_data: dict) -> dict: + """ + Decode Redis hash data (handles bytes from some Redis clients). + + Args: + raw_data: Raw data from redis.hgetall() + + Returns: + Dictionary with string keys and values + """ + decoded = {} + for key, value in raw_data.items(): + # Decode key if bytes + if isinstance(key, bytes): + key = key.decode('utf-8') + # Decode value if bytes + if isinstance(value, bytes): + value = value.decode('utf-8') + decoded[key] = value + return decoded + def get_usage_stats(self) -> dict: - """Get current global usage stats (for monitoring)""" + """Get current global usage stats (for monitoring).""" if not self.redis: return {"global_hourly": 0, "redis_available": False} - + try: global_count = int(self.redis.get(self.KEY_GLOBAL) or 0) return { @@ -289,12 +788,15 @@ def get_usage_stats(self) -> dict: return {"error": str(e), "redis_available": False} -# Singleton instance +# ============================================================================= +# SINGLETON +# ============================================================================= + _playground_limiter: Optional[PlaygroundLimiter] = None def get_playground_limiter(redis_client=None) -> PlaygroundLimiter: - """Get or create PlaygroundLimiter instance""" + """Get or create PlaygroundLimiter instance.""" global _playground_limiter if _playground_limiter is None: _playground_limiter = PlaygroundLimiter(redis_client) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 96bda1b..52815ad 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -52,14 +52,39 @@ def mock_pinecone(): @pytest.fixture(scope="session", autouse=True) def mock_redis(): - """Mock Redis client globally""" + """ + Mock Redis client globally. + + Includes hash operations for session management (#127). + """ with patch('redis.Redis') as mock: redis_instance = MagicMock() + + # Connection redis_instance.ping.return_value = True + + # String operations (legacy + IP/global limits) redis_instance.get.return_value = None redis_instance.set.return_value = True redis_instance.incr.return_value = 1 + redis_instance.delete.return_value = 1 + + # TTL operations redis_instance.expire.return_value = True + redis_instance.ttl.return_value = 86400 + + # Hash operations (session management #127) + redis_instance.type.return_value = b'hash' + redis_instance.hset.return_value = 1 + redis_instance.hget.return_value = b'0' + redis_instance.hgetall.return_value = { + b'searches_used': b'0', + b'created_at': b'2025-12-24T10:00:00Z', + } + redis_instance.hincrby.return_value = 1 + redis_instance.hexists.return_value = False + redis_instance.hdel.return_value = 1 + mock.return_value = redis_instance yield mock diff --git a/backend/tests/test_playground_limiter.py b/backend/tests/test_playground_limiter.py new file mode 100644 index 0000000..08d4b6a --- /dev/null +++ b/backend/tests/test_playground_limiter.py @@ -0,0 +1,519 @@ +""" +Test Suite for Playground Session Management +Issue #127 - Anonymous session management + +Tests: +- Session data retrieval +- Indexed repo storage +- Legacy session migration +- Rate limiting with hash storage +""" +import pytest +import json +from datetime import datetime, timezone, timedelta +from unittest.mock import MagicMock + +from services.playground_limiter import ( + PlaygroundLimiter, + SessionData, + IndexedRepoData, +) + + +# ============================================================================= +# FIXTURES +# ============================================================================= + +@pytest.fixture +def mock_redis(): + """Create a mock Redis client with hash support.""" + redis = MagicMock() + + # Default behaviors + redis.type.return_value = b'hash' + redis.hgetall.return_value = {} + redis.hget.return_value = None + redis.hset.return_value = 1 + redis.hincrby.return_value = 1 + redis.hexists.return_value = False + redis.hdel.return_value = 1 + redis.ttl.return_value = 86400 + redis.expire.return_value = True + redis.delete.return_value = 1 + redis.get.return_value = None + redis.incr.return_value = 1 + + return redis + + +@pytest.fixture +def limiter(mock_redis): + """Create a PlaygroundLimiter with mocked Redis.""" + return PlaygroundLimiter(redis_client=mock_redis) + + +@pytest.fixture +def limiter_no_redis(): + """Create a PlaygroundLimiter without Redis (fail-open mode).""" + return PlaygroundLimiter(redis_client=None) + + +@pytest.fixture +def sample_indexed_repo(): + """Sample indexed repo data.""" + return { + "repo_id": "repo_abc123", + "github_url": "https://github.com/pallets/flask", + "name": "flask", + "file_count": 198, + "indexed_at": "2025-12-24T10:05:00Z", + "expires_at": "2025-12-25T10:05:00Z", + } + + +# ============================================================================= +# DATA CLASS TESTS +# ============================================================================= + +class TestIndexedRepoData: + """Tests for IndexedRepoData dataclass.""" + + def test_from_dict(self, sample_indexed_repo): + """Should create IndexedRepoData from dictionary.""" + repo = IndexedRepoData.from_dict(sample_indexed_repo) + + assert repo.repo_id == "repo_abc123" + assert repo.name == "flask" + assert repo.file_count == 198 + + def test_to_dict(self, sample_indexed_repo): + """Should convert to dictionary.""" + repo = IndexedRepoData.from_dict(sample_indexed_repo) + result = repo.to_dict() + + assert result == sample_indexed_repo + + def test_is_expired_false(self): + """Should return False for non-expired repo.""" + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + repo = IndexedRepoData( + repo_id="abc", + github_url="https://github.com/user/repo", + name="repo", + file_count=100, + indexed_at="2025-12-24T10:00:00Z", + expires_at=future, + ) + + assert repo.is_expired() is False + + def test_is_expired_true(self): + """Should return True for expired repo.""" + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + repo = IndexedRepoData( + repo_id="abc", + github_url="https://github.com/user/repo", + name="repo", + file_count=100, + indexed_at="2025-12-24T10:00:00Z", + expires_at=past, + ) + + assert repo.is_expired() is True + + +class TestSessionData: + """Tests for SessionData dataclass.""" + + def test_to_response_empty(self): + """Should format empty session correctly.""" + session = SessionData() + result = session.to_response(limit=50) + + assert result["session_id"] is None + assert result["indexed_repo"] is None + assert result["searches"]["used"] == 0 + assert result["searches"]["limit"] == 50 + assert result["searches"]["remaining"] == 50 + + def test_to_response_with_data(self, sample_indexed_repo): + """Should format session with data correctly.""" + now = datetime.now(timezone.utc) + session = SessionData( + session_id="abc123def456ghi789", + searches_used=12, + created_at=now, + expires_at=now + timedelta(days=1), + indexed_repo=sample_indexed_repo, + ) + result = session.to_response(limit=50) + + assert result["session_id"] == "abc123de..." # Truncated (first 8 chars) + assert result["searches"]["used"] == 12 + assert result["searches"]["remaining"] == 38 + assert result["indexed_repo"] == sample_indexed_repo + + def test_truncate_id(self): + """Should truncate long session IDs.""" + assert SessionData._truncate_id("short") == "short" + assert SessionData._truncate_id("verylongsessiontoken123") == "verylong..." + + +# ============================================================================= +# GET SESSION DATA TESTS +# ============================================================================= + +class TestGetSessionData: + """Tests for get_session_data() method.""" + + def test_no_token_returns_empty(self, limiter): + """Should return empty SessionData when token is None.""" + result = limiter.get_session_data(None) + + assert result.session_id is None + assert result.searches_used == 0 + assert result.indexed_repo is None + + def test_no_redis_returns_session_id_only(self, limiter_no_redis): + """Should return partial data when Redis unavailable.""" + result = limiter_no_redis.get_session_data("some_token") + + assert result.session_id == "some_token" + assert result.searches_used == 0 + + def test_session_not_found(self, limiter, mock_redis): + """Should return empty SessionData for non-existent session.""" + mock_redis.hgetall.return_value = {} + + result = limiter.get_session_data("nonexistent_token") + + assert result.session_id is None + + def test_session_with_searches(self, limiter, mock_redis): + """Should return correct search count.""" + mock_redis.hgetall.return_value = { + b'searches_used': b'15', + b'created_at': b'2025-12-24T10:00:00Z', + } + + result = limiter.get_session_data("valid_token") + + assert result.session_id == "valid_token" + assert result.searches_used == 15 + + def test_session_with_indexed_repo(self, limiter, mock_redis, sample_indexed_repo): + """Should parse indexed_repo JSON correctly.""" + mock_redis.hgetall.return_value = { + b'searches_used': b'5', + b'created_at': b'2025-12-24T10:00:00Z', + b'indexed_repo': json.dumps(sample_indexed_repo).encode(), + } + + result = limiter.get_session_data("token_with_repo") + + assert result.indexed_repo is not None + assert result.indexed_repo["repo_id"] == "repo_abc123" + assert result.indexed_repo["name"] == "flask" + + def test_invalid_indexed_repo_json(self, limiter, mock_redis): + """Should handle invalid JSON gracefully.""" + mock_redis.hgetall.return_value = { + b'searches_used': b'5', + b'indexed_repo': b'not valid json{{{', + } + + result = limiter.get_session_data("token") + + assert result.indexed_repo is None # Graceful fallback + assert result.searches_used == 5 + + +# ============================================================================= +# SET INDEXED REPO TESTS +# ============================================================================= + +class TestSetIndexedRepo: + """Tests for set_indexed_repo() method.""" + + def test_no_token_returns_false(self, limiter): + """Should return False when token is None.""" + result = limiter.set_indexed_repo(None, {"repo_id": "abc"}) + + assert result is False + + def test_no_redis_returns_false(self, limiter_no_redis): + """Should return False when Redis unavailable.""" + result = limiter_no_redis.set_indexed_repo("token", {"repo_id": "abc"}) + + assert result is False + + def test_successful_set(self, limiter, mock_redis, sample_indexed_repo): + """Should store indexed repo successfully.""" + result = limiter.set_indexed_repo("valid_token", sample_indexed_repo) + + assert result is True + mock_redis.hset.assert_called() + + # Verify the JSON was stored + call_args = mock_redis.hset.call_args + stored_json = call_args[0][2] # Third argument is the value + stored_data = json.loads(stored_json) + assert stored_data["repo_id"] == "repo_abc123" + + def test_preserves_other_fields(self, limiter, mock_redis, sample_indexed_repo): + """Should not overwrite other session fields.""" + # Verify we use hset (field-level) not set (full replace) + limiter.set_indexed_repo("token", sample_indexed_repo) + + # Should call hset, not set + assert mock_redis.hset.called + assert not mock_redis.set.called + + +# ============================================================================= +# HAS INDEXED REPO TESTS +# ============================================================================= + +class TestHasIndexedRepo: + """Tests for has_indexed_repo() method.""" + + def test_no_token_returns_false(self, limiter): + """Should return False when token is None.""" + assert limiter.has_indexed_repo(None) is False + assert limiter.has_indexed_repo("") is False + + def test_no_redis_returns_false(self, limiter_no_redis): + """Should return False when Redis unavailable.""" + assert limiter_no_redis.has_indexed_repo("token") is False + + def test_repo_exists(self, limiter, mock_redis): + """Should return True when repo exists.""" + mock_redis.hexists.return_value = True + + assert limiter.has_indexed_repo("token") is True + mock_redis.hexists.assert_called_with( + "playground:session:token", + "indexed_repo" + ) + + def test_repo_not_exists(self, limiter, mock_redis): + """Should return False when repo doesn't exist.""" + mock_redis.hexists.return_value = False + + assert limiter.has_indexed_repo("token") is False + + +# ============================================================================= +# CLEAR INDEXED REPO TESTS +# ============================================================================= + +class TestClearIndexedRepo: + """Tests for clear_indexed_repo() method.""" + + def test_successful_clear(self, limiter, mock_redis): + """Should clear indexed repo successfully.""" + result = limiter.clear_indexed_repo("valid_token") + + assert result is True + mock_redis.hdel.assert_called_with( + "playground:session:valid_token", + "indexed_repo" + ) + + def test_no_token_returns_false(self, limiter): + """Should return False when token is None.""" + assert limiter.clear_indexed_repo(None) is False + + +# ============================================================================= +# LEGACY MIGRATION TESTS +# ============================================================================= + +class TestLegacyMigration: + """Tests for legacy string format migration.""" + + def test_migrate_string_to_hash(self, limiter, mock_redis): + """Should migrate legacy string format to hash.""" + # Simulate legacy string format + mock_redis.type.return_value = b'string' + mock_redis.get.return_value = b'25' # Old count + mock_redis.ttl.return_value = 3600 + + limiter._ensure_hash_format("legacy_token") + + # Verify migration happened + mock_redis.delete.assert_called() + mock_redis.hset.assert_called() + mock_redis.expire.assert_called_with( + "playground:session:legacy_token", + 3600 + ) + + def test_already_hash_no_migration(self, limiter, mock_redis): + """Should not migrate if already hash format.""" + mock_redis.type.return_value = b'hash' + + limiter._ensure_hash_format("hash_token") + + # Should NOT call delete (no migration needed) + mock_redis.delete.assert_not_called() + + def test_nonexistent_key_no_error(self, limiter, mock_redis): + """Should handle non-existent keys gracefully.""" + mock_redis.type.return_value = b'none' + + # Should not raise + limiter._ensure_hash_format("new_token") + + +# ============================================================================= +# RATE LIMITING WITH HASH STORAGE TESTS +# ============================================================================= + +class TestRateLimitingWithHash: + """Tests to verify rate limiting still works with hash storage.""" + + def test_check_and_record_uses_hincrby(self, limiter, mock_redis): + """check_and_record should use HINCRBY for atomic increment.""" + mock_redis.hincrby.return_value = 5 + + result = limiter.check_and_record("token", "127.0.0.1") + + assert result.allowed is True + assert result.remaining == 45 # 50 - 5 + mock_redis.hincrby.assert_called() + + def test_session_limit_enforced(self, limiter, mock_redis): + """Should enforce session limit with hash storage.""" + mock_redis.hincrby.return_value = 51 # Over limit + + result = limiter.check_and_record("token", "127.0.0.1") + + assert result.allowed is False + assert result.remaining == 0 + + def test_new_session_creates_hash(self, limiter, mock_redis): + """New sessions should be created with hash structure.""" + mock_redis.type.return_value = b'none' # Key doesn't exist + + result = limiter.check_and_record(None, "127.0.0.1") + + assert result.allowed is True + assert result.session_token is not None + # Should create hash with hset + mock_redis.hset.assert_called() + + +# ============================================================================= +# CREATE SESSION TESTS +# ============================================================================= + +class TestCreateSession: + """Tests for create_session() method.""" + + def test_successful_create(self, limiter, mock_redis): + """Should create session with initial values.""" + result = limiter.create_session("new_token") + + assert result is True + mock_redis.hset.assert_called() + mock_redis.expire.assert_called_with( + "playground:session:new_token", + 86400 # TTL_DAY + ) + + def test_no_token_returns_false(self, limiter): + """Should return False for empty token.""" + assert limiter.create_session(None) is False + assert limiter.create_session("") is False + + def test_no_redis_returns_false(self, limiter_no_redis): + """Should return False when Redis unavailable.""" + assert limiter_no_redis.create_session("token") is False + + +# ============================================================================= +# HELPER METHOD TESTS +# ============================================================================= + +class TestHelperMethods: + """Tests for helper methods.""" + + def test_decode_hash_data_bytes(self, limiter): + """Should decode bytes from Redis.""" + raw = {b'key1': b'value1', b'key2': b'value2'} + result = limiter._decode_hash_data(raw) + + assert result == {'key1': 'value1', 'key2': 'value2'} + + def test_decode_hash_data_strings(self, limiter): + """Should handle already-decoded strings.""" + raw = {'key1': 'value1', 'key2': 'value2'} + result = limiter._decode_hash_data(raw) + + assert result == {'key1': 'value1', 'key2': 'value2'} + + def test_generate_session_token(self, limiter): + """Should generate unique tokens.""" + token1 = limiter._generate_session_token() + token2 = limiter._generate_session_token() + + assert token1 != token2 + assert len(token1) > 20 # Should be reasonably long + + +# ============================================================================= +# INTEGRATION-STYLE TESTS +# ============================================================================= + +class TestSessionWorkflow: + """End-to-end workflow tests.""" + + def test_full_session_lifecycle(self, mock_redis): + """Test complete session lifecycle: create → search → index → search.""" + limiter = PlaygroundLimiter(redis_client=mock_redis) + + # 1. First search creates session + mock_redis.hincrby.return_value = 1 + result = limiter.check_and_record(None, "127.0.0.1") + assert result.allowed is True + token = result.session_token + assert token is not None + + # 2. Check session data + mock_redis.hgetall.return_value = { + b'searches_used': b'1', + b'created_at': b'2025-12-24T10:00:00Z', + } + session = limiter.get_session_data(token) + assert session.searches_used == 1 + assert session.indexed_repo is None + + # 3. User has no repo yet + mock_redis.hexists.return_value = False + assert limiter.has_indexed_repo(token) is False + + # 4. Index a repo + repo_data = { + "repo_id": "repo_123", + "github_url": "https://github.com/user/repo", + "name": "repo", + "file_count": 150, + "indexed_at": "2025-12-24T10:05:00Z", + "expires_at": "2025-12-25T10:05:00Z", + } + assert limiter.set_indexed_repo(token, repo_data) is True + + # 5. Now has repo + mock_redis.hexists.return_value = True + assert limiter.has_indexed_repo(token) is True + + # 6. More searches + mock_redis.hincrby.return_value = 10 + result = limiter.check_and_record(token, "127.0.0.1") + assert result.allowed is True + assert result.remaining == 40 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/docs/HANDOFF-114.md b/docs/HANDOFF-114.md new file mode 100644 index 0000000..a938fc6 --- /dev/null +++ b/docs/HANDOFF-114.md @@ -0,0 +1,60 @@ +# Handoff: Anonymous Indexing (#114) + +## TL;DR +Let users index their own GitHub repos without signup. 5 backend endpoints needed. + +## GitHub Issues (Full Specs) +- **#124** - Validate GitHub URL +- **#125** - Start anonymous indexing +- **#126** - Get indexing status +- **#127** - Extend session management +- **#128** - Update search for user repos + +**Read these first.** Each has request/response schemas, implementation notes, acceptance criteria. + +## Order of Work +``` +#127 + #124 (parallel) → #125 → #126 → #128 +``` + +## Key Files to Understand + +| File | What It Does | +|------|--------------| +| `backend/config/api.py` | API versioning (`/api/v1/*`) | +| `backend/routes/playground.py` | Existing playground endpoints | +| `backend/services/playground_limiter.py` | Session + rate limiting | +| `backend/services/repo_validator.py` | File counting, extensions | +| `backend/dependencies.py` | Indexer, cache, redis_client | + +## Constraints (Anonymous Users) +- 200 files max +- 1 repo per session +- 50 searches per session +- 24hr TTL + +## Workflow +See `CONTRIBUTING.md` for full guide. + +**Quick version:** +```bash +# Create branch +git checkout -b feat/124-validate-repo + +# Make changes, test +pytest tests/ -v + +# Commit +git add . +git commit -m "feat(playground): add validate-repo endpoint" + +# Push to YOUR fork +git push origin feat/124-validate-repo + +# Create PR on OpenCodeIntel/opencodeintel +# Reference issue: "Closes #124" +``` + +## Questions? +- Check GitHub issues first +- Ping Devanshu for blockers