diff --git a/backend/routes/repos.py b/backend/routes/repos.py index 80bc709..f3f99a4 100644 --- a/backend/routes/repos.py +++ b/backend/routes/repos.py @@ -1,7 +1,8 @@ """Repository management routes - CRUD and indexing.""" from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Depends, BackgroundTasks -from pydantic import BaseModel -from typing import Optional +from pydantic import BaseModel, field_validator +from typing import List, Optional +from pathlib import Path import hashlib import time import asyncio @@ -177,6 +178,56 @@ async def delete_repository( raise HTTPException(status_code=500, detail="Failed to delete repository") +def _scan_directories(local_path: Path) -> List[dict]: + """Scan top-level directories and count code files in each. + + Runs synchronously -- call via asyncio.to_thread() from async handlers + to avoid blocking the event loop on large repos. + """ + skip = {"node_modules", ".git", "__pycache__", "venv", ".next", "dist", "build"} + extensions = {".py", ".js", ".jsx", ".ts", ".tsx"} + dirs = [] + for item in sorted(local_path.iterdir()): + if item.is_dir() and item.name not in skip and not item.name.startswith("."): + file_count = sum( + 1 for f in item.rglob("*") + if f.is_file() and f.suffix in extensions + and not any(s in f.parts for s in skip) + ) + dirs.append({ + "name": item.name, + "path": str(item.relative_to(local_path)), + "file_count": file_count, + }) + return dirs + + +@router.get("/{repo_id}/directories") +async def get_repo_directories( + repo_id: str, + auth: AuthContext = Depends(require_auth), +) -> dict: + """Return the top-level directory tree of a cloned repo. + + Used for monorepo subset selection -- lets the user pick which + directories to index instead of the entire repo. + """ + repo = get_repo_or_404(repo_id, auth.user_id) + local_path = Path(repo["local_path"]) + + if not local_path.exists(): + raise HTTPException(status_code=404, detail="Repo not cloned yet") + + dirs = await asyncio.to_thread(_scan_directories, local_path) + + return { + "repo_id": repo_id, + "repo_name": repo.get("name", local_path.name), + "directories": dirs, + "total_directories": len(dirs), + } + + @router.post("/{repo_id}/index") async def index_repository( repo_id: str, @@ -275,7 +326,8 @@ async def _run_async_indexing( repo_id: str, repo: dict, user_id: str, - incremental: bool = True + incremental: bool = True, + include_paths: Optional[List[str]] = None, ): """ Background task for async indexing with real-time progress. @@ -298,9 +350,12 @@ async def _run_async_indexing( publisher.publish_progress(repo_id, 0, 1, 0, "Starting...") # Check for incremental + # Skip incremental when include_paths is set -- incremental_index_repository + # uses git diff which doesn't understand subset boundaries last_commit = repo_manager.get_last_indexed_commit(repo_id) + can_incremental = incremental and last_commit and not include_paths - if incremental and last_commit: + if can_incremental: logger.info("Async INCREMENTAL indexing", repo_id=repo_id, last_commit=last_commit[:8]) total_functions = await indexer.incremental_index_repository( repo_id, @@ -349,7 +404,8 @@ async def progress_callback( total_functions = await indexer.index_repository_with_progress( repo_id, repo["local_path"], - progress_callback + progress_callback, + include_paths=include_paths, ) total_files = tracked_total_files index_type = "full" @@ -400,11 +456,35 @@ async def progress_callback( ) +class IndexConfig(BaseModel): + """Optional config for indexing -- supports monorepo subset selection.""" + include_paths: Optional[List[str]] = None # e.g. ["packages/effect", "packages/schema"] + incremental: bool = True + + @field_validator("include_paths", mode="before") + @classmethod + def sanitize_paths(cls, v: Optional[List[str]]) -> Optional[List[str]]: + """Reject path traversal, empty strings, and normalize slashes.""" + if v is None: + return v + cleaned = [] + for item in v: + if not isinstance(item, str): + raise ValueError(f"include_paths entries must be strings, got {type(item).__name__}") + item = item.replace("\\", "/").strip().strip("/") + if not item: + raise ValueError("include_paths entries must not be empty") + if ".." in item.split("/"): + raise ValueError(f"Path traversal not allowed: {item}") + cleaned.append(item) + return cleaned + + @router.post("/{repo_id}/index/async", status_code=202) async def index_repository_async( repo_id: str, background_tasks: BackgroundTasks, - incremental: bool = True, + config: IndexConfig = IndexConfig(), auth: AuthContext = Depends(require_auth) ): """ @@ -463,14 +543,16 @@ async def index_repository_async( repo_id, repo, user_id, - incremental + incremental=config.incremental, + include_paths=config.include_paths, ) return { "status": "indexing", "repo_id": repo_id, "message": "Indexing started. Connect to WebSocket for progress.", - "websocket_url": f"/api/v1/ws/repos/{repo_id}/indexing" + "websocket_url": f"/api/v1/ws/repos/{repo_id}/indexing", + "include_paths": config.include_paths, } except HTTPException: @@ -500,7 +582,13 @@ async def _authenticate_websocket(websocket: WebSocket) -> Optional[dict]: # Note: WebSocket routes need to be registered on the main app, not router # This function is exported and called from main.py async def websocket_index(websocket: WebSocket, repo_id: str): - """Real-time repository indexing with progress updates.""" + """Real-time repository indexing with progress updates. + + NOTE: This WebSocket-direct-indexing path does NOT support include_paths + (monorepo subset selection). Use the HTTP async endpoint instead: + POST /repos/{id}/index/async with IndexConfig body. + This handler is the older pattern -- kept for backward compatibility. + """ user = await _authenticate_websocket(websocket) if not user: return diff --git a/backend/services/indexer_optimized.py b/backend/services/indexer_optimized.py index d8be0ed..60f9486 100644 --- a/backend/services/indexer_optimized.py +++ b/backend/services/indexer_optimized.py @@ -117,30 +117,44 @@ def _detect_language(self, file_path: str) -> Optional[str]: } return lang_map.get(ext) - def _discover_code_files(self, repo_path: str) -> List[Path]: - """Find all code files in repository""" + def _discover_code_files( + self, repo_path: str, include_paths: Optional[List[str]] = None + ) -> List[Path]: + """Find all code files in repository. + + Args: + include_paths: If set, only include files under these relative + directories (e.g. ['packages/effect', 'packages/schema']). + Uses path-component-aware matching and only walks the + specified subtrees instead of the entire repo. + """ repo_path = Path(repo_path) code_files = [] - - # Extensions to index + extensions = {'.py', '.js', '.jsx', '.ts', '.tsx'} - - # Directories to skip skip_dirs = {'node_modules', '.git', '__pycache__', 'venv', 'env', 'dist', 'build', '.next', '.vscode'} - - for file_path in repo_path.rglob('*'): - # Skip directories - if file_path.is_dir(): - continue - - # Skip if in excluded directory - if any(skip in file_path.parts for skip in skip_dirs): - continue - - # Check extension - if file_path.suffix in extensions: - code_files.append(file_path) - + + # When include_paths is set, only walk those subtrees + if include_paths: + roots = [] + for p in include_paths: + subtree = repo_path / p + if subtree.is_dir(): + roots.append(subtree) + else: + logger.warning("include_path not found, skipping: %s", p) + else: + roots = [repo_path] + + for root in roots: + for file_path in root.rglob('*'): + if file_path.is_dir(): + continue + if any(skip in file_path.parts for skip in skip_dirs): + continue + if file_path.suffix in extensions: + code_files.append(file_path) + return code_files async def _create_embeddings_batch(self, texts: List[str]) -> List[List[float]]: @@ -349,11 +363,16 @@ async def _extract_functions_from_file( logger.error("Error processing file", file_path=file_path, error=str(e)) return [] - def extract_functions_v2(self, repo_path: str, max_functions: int = 5000) -> List[ExtractedFunction]: + def extract_functions_v2( + self, repo_path: str, max_functions: int = 5000, + include_paths: Optional[List[str]] = None, + ) -> List[ExtractedFunction]: """Extract and filter functions using tree-sitter.""" from pathlib import Path - raw = self.tree_sitter_extractor.extract_from_repo(Path(repo_path), max_functions=max_functions) + raw = self.tree_sitter_extractor.extract_from_repo( + Path(repo_path), include_paths=include_paths, max_functions=max_functions, + ) filtered = self.function_filter.filter_functions(raw) logger.info("V2 extraction", total=len(raw), kept=len(filtered)) @@ -397,15 +416,17 @@ async def index_repository_v2( repo_id: str, repo_path: str, progress_callback=None, - generate_summaries: bool = False + generate_summaries: bool = False, + include_paths: Optional[List[str]] = None, ) -> int: """Index repository using V2 function-level extraction.""" from services.search_v2 import generate_summaries as gen_summaries start_time = time.time() - logger.info("V2 indexing started", repo_id=repo_id, with_summaries=generate_summaries) + logger.info("V2 indexing started", repo_id=repo_id, with_summaries=generate_summaries, + include_paths=include_paths) - functions = self.extract_functions_v2(repo_path) + functions = self.extract_functions_v2(repo_path, include_paths=include_paths) if not functions: if progress_callback: await progress_callback(0, 0, 0) @@ -691,18 +712,21 @@ async def index_repository_with_progress( repo_id: str, repo_path: str, progress_callback, - max_files: int = None - ): + max_files: int = None, + include_paths: Optional[List[str]] = None, + ) -> int: """Index repository with real-time progress updates Args: max_files: If set, limit indexing to first N files (for partial indexing) + include_paths: If set, only index files under these directories """ start_time = time.time() - logger.info("Starting optimized indexing with progress", repo_id=repo_id) + logger.info("Starting optimized indexing with progress", repo_id=repo_id, + include_paths=include_paths) - # Discover code files - code_files = self._discover_code_files(repo_path) + # Discover code files (filtered by include_paths if set) + code_files = self._discover_code_files(repo_path, include_paths=include_paths) # Apply file limit if specified (partial indexing) if max_files and len(code_files) > max_files: