diff --git a/backend/dependencies.py b/backend/dependencies.py index 0a2e93b..9ccff25 100644 --- a/backend/dependencies.py +++ b/backend/dependencies.py @@ -18,6 +18,7 @@ from services.supabase_service import get_supabase_service from services.input_validator import InputValidator, CostController from services.user_limits import init_user_limits_service, get_user_limits_service +from services.repo_validator import get_repo_validator # Service instances (singleton pattern) indexer = OptimizedCodeIndexer() @@ -38,6 +39,9 @@ redis_client=cache.redis if cache.redis else None ) +# Repository size validation +repo_validator = get_repo_validator() + def get_repo_or_404(repo_id: str, user_id: str) -> dict: """ diff --git a/backend/routes/repos.py b/backend/routes/repos.py index 42041ce..59d8454 100644 --- a/backend/routes/repos.py +++ b/backend/routes/repos.py @@ -8,7 +8,7 @@ from dependencies import ( indexer, repo_manager, metrics, - get_repo_or_404, cost_controller + get_repo_or_404, user_limits, repo_validator ) from services.input_validator import InputValidator from middleware.auth import require_auth, AuthContext @@ -38,9 +38,13 @@ async def add_repository( request: AddRepoRequest, auth: AuthContext = Depends(require_auth) ): - """Add a new repository with validation and cost controls.""" + """Add a new repository with validation and tier-based limits.""" user_id = auth.user_id or auth.identifier + # Validate user_id + if not user_id: + raise HTTPException(status_code=401, detail="User ID required") + # Validate inputs valid_name, name_error = InputValidator.validate_repo_name(request.name) if not valid_name: @@ -50,13 +54,17 @@ async def add_repository( if not valid_url: raise HTTPException(status_code=400, detail=f"Invalid Git URL: {url_error}") - # Check repo limit - user_id_hash = hashlib.sha256(user_id.encode()).hexdigest() - can_add, limit_error = cost_controller.check_repo_limit(user_id, user_id_hash) - if not can_add: - raise HTTPException(status_code=429, 
detail=limit_error) + # Check repo count limit (tier-aware) - #95 + repo_count_check = user_limits.check_repo_count(user_id) + if not repo_count_check.allowed: + raise HTTPException( + status_code=403, + detail=repo_count_check.to_dict() + ) try: + # Clone repo first + user_id_hash = hashlib.sha256(user_id.encode()).hexdigest() repo = repo_manager.add_repo( name=request.name, git_url=request.git_url, @@ -65,22 +73,60 @@ async def add_repository( api_key_hash=user_id_hash ) - # Check repo size - can_index, size_error = cost_controller.check_repo_size_limit(repo["local_path"]) - if not can_index: + # Analyze repo size - #94 + analysis = repo_validator.analyze_repo(repo["local_path"]) + + # Fail CLOSED if analysis failed (security: don't allow unknown-size repos) + if not analysis.success: + logger.error( + "Repo analysis failed - blocking indexing", + user_id=user_id, + repo_id=repo["id"], + error=analysis.error + ) + return { + "repo_id": repo["id"], + "status": "added", + "indexing_blocked": True, + "analysis": analysis.to_dict(), + "message": f"Repository added but analysis failed: {analysis.error}. Please try re-indexing later." 
+ } + + # Check repo size against tier limits + size_check = user_limits.check_repo_size( + user_id, + analysis.file_count, + analysis.estimated_functions + ) + + if not size_check.allowed: + # Repo added but too large - return warning with upgrade CTA + logger.info( + "Repo too large for user tier", + user_id=user_id, + repo_id=repo["id"], + file_count=analysis.file_count, + tier=size_check.tier + ) return { "repo_id": repo["id"], "status": "added", - "warning": size_error, - "message": "Repository added but too large for automatic indexing" + "indexing_blocked": True, + "analysis": analysis.to_dict(), + "limit_check": size_check.to_dict(), + "message": size_check.message } return { "repo_id": repo["id"], "status": "added", - "message": "Repository added successfully" + "indexing_blocked": False, + "analysis": analysis.to_dict(), + "message": "Repository added successfully. Ready for indexing." } except Exception as e: + logger.error("Failed to add repository", error=str(e), user_id=user_id) + capture_exception(e) raise HTTPException(status_code=400, detail=str(e)) @@ -90,11 +136,48 @@ async def index_repository( incremental: bool = True, auth: AuthContext = Depends(require_auth) ): - """Trigger indexing for a repository.""" + """Trigger indexing for a repository with tier-based size limits.""" start_time = time.time() + user_id = auth.user_id + + # Validate user_id + if not user_id: + raise HTTPException(status_code=401, detail="User ID required") try: - repo = get_repo_or_404(repo_id, auth.user_id) + repo = get_repo_or_404(repo_id, user_id) + + # Re-check size limits before indexing (in case tier changed or repo updated) + analysis = repo_validator.analyze_repo(repo["local_path"]) + + # Fail CLOSED if analysis failed + if not analysis.success: + raise HTTPException( + status_code=500, + detail={ + "error": "ANALYSIS_FAILED", + "analysis": analysis.to_dict(), + "message": f"Cannot index: {analysis.error}" + } + ) + + size_check = user_limits.check_repo_size( + 
user_id, + analysis.file_count, + analysis.estimated_functions + ) + + if not size_check.allowed: + raise HTTPException( + status_code=403, + detail={ + "error": "REPO_TOO_LARGE", + "analysis": analysis.to_dict(), + "limit_check": size_check.to_dict(), + "message": size_check.message + } + ) + repo_manager.update_status(repo_id, "indexing") # Check for incremental @@ -132,7 +215,12 @@ async def index_repository( "index_type": index_type, "commit": current_commit[:8] } + except HTTPException: + raise except Exception as e: + logger.error("Indexing failed", repo_id=repo_id, error=str(e)) + capture_exception(e) + repo_manager.update_status(repo_id, "error") raise HTTPException(status_code=500, detail=str(e)) @@ -170,6 +258,24 @@ async def websocket_index(websocket: WebSocket, repo_id: str): await websocket.close(code=4004, reason="Repository not found") return + # Check size limits before WebSocket indexing + analysis = repo_validator.analyze_repo(repo["local_path"]) + + # Fail CLOSED if analysis failed + if not analysis.success: + await websocket.close(code=4005, reason=f"Analysis failed: {analysis.error}") + return + + size_check = user_limits.check_repo_size( + user_id, + analysis.file_count, + analysis.estimated_functions + ) + + if not size_check.allowed: + await websocket.close(code=4003, reason=size_check.message) + return + await websocket.accept() try: diff --git a/backend/services/repo_validator.py b/backend/services/repo_validator.py new file mode 100644 index 0000000..a115d85 --- /dev/null +++ b/backend/services/repo_validator.py @@ -0,0 +1,362 @@ +""" +Repository Validator Service +Analyzes repository size before indexing to enforce tier limits. + +Part of #94 (repo size limits) implementation. 
@dataclass
class RepoAnalysis:
    """Result of a repository size analysis.

    Attributes:
        file_count: Number of code files found (may be partial when
            ``error`` is set because the scan was incomplete).
        estimated_functions: Estimated number of function/class definitions.
        sampled: True when the estimate was extrapolated from a random
            sample of files instead of the flat per-file average.
        error: Human-readable failure reason, or None when analysis succeeded.
    """
    file_count: int
    estimated_functions: int
    sampled: bool  # True if we used sampling for large repos
    error: Optional[str] = None  # Error message if analysis failed

    @property
    def success(self) -> bool:
        """True if analysis completed without error."""
        return self.error is None

    def to_dict(self) -> dict:
        """Return a JSON-serialisable view; ``error`` is included on failure."""
        result = {
            "file_count": self.file_count,
            "estimated_functions": self.estimated_functions,
            "sampled": self.sampled,
            "success": self.success,
        }
        # Same test as `success`, so a failed analysis always reports its
        # error key (even if the message happens to be an empty string).
        if self.error is not None:
            result["error"] = self.error
        return result


class RepoValidator:
    """
    Validates repository size before indexing.

    Usage:
        validator = RepoValidator()
        analysis = validator.analyze_repo("/path/to/repo")

        # Then check against user limits
        result = user_limits.check_repo_size(
            user_id,
            analysis.file_count,
            analysis.estimated_functions
        )

    Callers are expected to fail CLOSED: when ``analysis.success`` is False
    the repository size is unknown and indexing should be blocked.
    """

    # Code file extensions we index
    CODE_EXTENSIONS: Set[str] = {
        '.py',     # Python
        '.js',     # JavaScript
        '.jsx',    # React
        '.ts',     # TypeScript
        '.tsx',    # React TypeScript
        '.go',     # Go
        '.rs',     # Rust
        '.java',   # Java
        '.rb',     # Ruby
        '.php',    # PHP
        '.c',      # C
        '.cpp',    # C++
        '.h',      # C/C++ headers
        '.hpp',    # C++ headers
        '.cs',     # C#
        '.swift',  # Swift
        '.kt',     # Kotlin
        '.scala',  # Scala
    }

    # Directories to skip (common non-code dirs), matched by exact name
    SKIP_DIRS: Set[str] = {
        'node_modules',
        '.git',
        '__pycache__',
        '.pytest_cache',
        'venv',
        'env',
        '.venv',
        '.env',
        'dist',
        'build',
        'target',    # Rust/Java build
        '.next',     # Next.js
        '.nuxt',     # Nuxt.js
        'vendor',    # PHP/Go
        'coverage',
        '.coverage',
        'htmlcov',
        '.tox',
        '.mypy_cache',
        '.ruff_cache',
        'egg-info',
        '.eggs',
    }

    # Directory-name suffixes to skip. Packaging metadata directories are
    # named "<project>.egg-info", which the exact-name set above cannot match.
    SKIP_DIR_SUFFIXES: tuple = ('.egg-info',)

    # Average functions per file (rough estimate)
    # Used for quick estimation without parsing
    AVG_FUNCTIONS_PER_FILE = 25

    # Sample size for large repos
    SAMPLE_SIZE = 100
    SAMPLE_THRESHOLD = 500  # Use sampling if more than this many files

    # Max file size to read (10MB) - prevent OOM on huge files
    MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024

    # Substrings counted per extension as a cheap proxy for "one definition".
    # Extensions missing from this table fall back to AVG_FUNCTIONS_PER_FILE.
    _DEFINITION_MARKERS = {
        '.py': ('\ndef ', '\nclass '),
        '.js': ('function ', '=>', '\nclass '),
        '.jsx': ('function ', '=>', '\nclass '),
        '.ts': ('function ', '=>', '\nclass '),
        '.tsx': ('function ', '=>', '\nclass '),
        '.go': ('\nfunc ',),
        '.java': ('public ', 'private ', 'protected '),
        '.cs': ('public ', 'private ', 'protected '),
        '.kt': ('public ', 'private ', 'protected '),
        '.scala': ('public ', 'private ', 'protected '),
        '.rb': ('\ndef ', '\nclass '),
        '.rs': ('\nfn ', '\nimpl '),
        # Very rough for C/C++: an opening brace right after a parenthesis
        '.c': (') {',),
        '.cpp': (') {',),
        '.h': (') {',),
        '.hpp': (') {',),
        '.php': ('function ', '\nclass '),
        '.swift': ('\nfunc ', '\nclass '),
    }

    def analyze_repo(self, repo_path: str) -> RepoAnalysis:
        """
        Analyze repository to count files and estimate functions.

        For small repos (<= SAMPLE_THRESHOLD files): file count times the
        flat per-file average.
        For large repos: a random sample of SAMPLE_SIZE files is scanned and
        the per-file average is extrapolated to the whole repo.

        Args:
            repo_path: Path to cloned repository

        Returns:
            RepoAnalysis with file count and estimated function count.
            Never raises: any failure is reported through ``error`` so the
            caller can fail CLOSED (block indexing).
        """
        try:
            # Validate input
            if not repo_path or not isinstance(repo_path, str) or not repo_path.strip():
                logger.warning("Invalid repo_path provided", repo_path=repo_path)
                return self._failure("Invalid repository path: empty or not a string")

            # Validate path exists first
            repo_root = Path(repo_path)
            if not repo_root.exists():
                logger.warning("Repo path does not exist", repo_path=repo_path)
                return self._failure(f"Repository path does not exist: {repo_path}")

            if not repo_root.is_dir():
                logger.warning("Repo path is not a directory", repo_path=repo_path)
                return self._failure(f"Repository path is not a directory: {repo_path}")

            code_files, scan_error = self._find_code_files(repo_path)

            # Fail CLOSED if scan had errors (could have undercounted)
            if scan_error:
                logger.error("Repo scan incomplete", repo_path=repo_path, error=scan_error)
                return self._failure(
                    f"Scan incomplete: {scan_error}",
                    file_count=len(code_files),
                )

            file_count = len(code_files)

            if file_count == 0:
                return RepoAnalysis(file_count=0, estimated_functions=0, sampled=False)

            # For small repos, estimate directly
            if file_count <= self.SAMPLE_THRESHOLD:
                return RepoAnalysis(
                    file_count=file_count,
                    estimated_functions=file_count * self.AVG_FUNCTIONS_PER_FILE,
                    sampled=False,
                )

            # For large repos, sample and extrapolate
            sample = random.sample(code_files, min(self.SAMPLE_SIZE, file_count))
            sample_functions = self._count_functions_in_files(sample)

            # Extrapolate to full repo
            avg_per_sampled = sample_functions / len(sample)
            estimated_functions = int(avg_per_sampled * file_count)

            logger.info(
                "Repo analysis complete (sampled)",
                repo_path=repo_path,
                file_count=file_count,
                sample_size=len(sample),
                estimated_functions=estimated_functions
            )

            return RepoAnalysis(
                file_count=file_count,
                estimated_functions=estimated_functions,
                sampled=True,
            )

        except Exception as e:
            logger.error("Repo analysis failed", repo_path=repo_path, error=str(e))
            capture_exception(e)
            # Return error - caller should fail CLOSED (block indexing)
            return self._failure(f"Analysis failed: {str(e)}")

    @staticmethod
    def _failure(error: str, file_count: int = 0) -> RepoAnalysis:
        """Build a failed RepoAnalysis carrying ``error``."""
        return RepoAnalysis(
            file_count=file_count,
            estimated_functions=0,
            sampled=False,
            error=error,
        )

    def _is_skipped_dir(self, name: str) -> bool:
        """Return True if a directory with this bare name must not be scanned."""
        return name in self.SKIP_DIRS or name.endswith(self.SKIP_DIR_SUFFIXES)

    def _describe_scan_error(self, exc: BaseException) -> str:
        """Log a scan failure and return the message reported to the caller."""
        if isinstance(exc, PermissionError):
            logger.warning("Permission denied during repo scan", error=str(exc))
            return f"Permission denied: {str(exc)}"
        logger.error("Error scanning repo", error=str(exc))
        capture_exception(exc)
        return f"Scan failed: {str(exc)}"

    def _find_code_files(self, repo_path: str) -> tuple[list[Path], Optional[str]]:
        """
        Find all code files in repository (assumes path validated by caller).

        Excluded directories are matched only against names INSIDE the
        repository. Ancestors of ``repo_path`` (and the repository's own
        directory name) are never considered, so a checkout stored under e.g.
        ``/srv/build/...`` is still scanned instead of being reported as empty.

        Returns:
            Tuple of (code_files, error_message)
            If error_message is not None, the scan was incomplete
        """
        # Local import so this block does not depend on the module's import list.
        import os

        code_files: list[Path] = []
        problems: list[str] = []

        def _on_walk_error(exc: OSError) -> None:
            # os.walk ignores unreadable directories unless onerror is given;
            # record them so the caller knows the count may be too low.
            problems.append(self._describe_scan_error(exc))

        try:
            # followlinks defaults to False: symlinked directories are listed
            # but never entered (security: no scanning outside the repo).
            for current_dir, subdirs, filenames in os.walk(repo_path, onerror=_on_walk_error):
                # Prune in place so os.walk never descends into excluded trees
                subdirs[:] = [name for name in subdirs if not self._is_skipped_dir(name)]

                for filename in filenames:
                    candidate = Path(current_dir, filename)

                    # Check extension
                    if candidate.suffix.lower() not in self.CODE_EXTENSIONS:
                        continue

                    # Skip symlinks and anything that is not a regular file
                    # (e.g. FIFOs, which would block when read later).
                    if candidate.is_symlink() or not candidate.is_file():
                        continue

                    code_files.append(candidate)

        except Exception as exc:
            problems.append(self._describe_scan_error(exc))

        return code_files, (problems[0] if problems else None)

    def _count_functions_in_files(self, files: list[Path]) -> int:
        """
        Count approximate function definitions in files.

        Uses simple substring heuristics (not full AST parsing) for speed;
        see _DEFINITION_MARKERS for the tokens counted per extension.

        Security: Skips files larger than MAX_FILE_SIZE_BYTES to prevent OOM.
        Files that are too large, unreadable, or of an unknown extension
        contribute AVG_FUNCTIONS_PER_FILE instead.
        """
        total = 0

        for file_path in files:
            try:
                # Security: Skip huge files to prevent OOM
                file_size = file_path.stat().st_size
                if file_size > self.MAX_FILE_SIZE_BYTES:
                    logger.debug("Skipping large file", path=str(file_path), size=file_size)
                    total += self.AVG_FUNCTIONS_PER_FILE  # Estimate instead
                    continue

                content = file_path.read_text(encoding='utf-8', errors='ignore')
                ext = file_path.suffix.lower()

                markers = self._DEFINITION_MARKERS.get(ext)
                if markers is None:
                    # Default estimate
                    total += self.AVG_FUNCTIONS_PER_FILE
                    continue

                total += sum(content.count(marker) for marker in markers)

                # The Python markers start with a newline, so a definition on
                # the very first line of the file needs counting separately.
                if ext == '.py' and content.startswith(('def ', 'class ')):
                    total += 1

            except Exception:
                # Best effort: if we can't read a file, use average
                total += self.AVG_FUNCTIONS_PER_FILE

        return total

    def quick_file_count(self, repo_path: str) -> int:
        """
        Quick file count without full analysis.
        Useful for fast pre-checks.

        Scan errors are ignored here, so the result may undercount; use
        analyze_repo() when the count gates a limit.
        """
        files, _ = self._find_code_files(repo_path)
        return len(files)


# Singleton instance
_repo_validator: Optional[RepoValidator] = None


def get_repo_validator() -> RepoValidator:
    """Get or create the shared RepoValidator instance."""
    global _repo_validator
    if _repo_validator is None:
        _repo_validator = RepoValidator()
    return _repo_validator