diff --git a/mcp-server/.env.example b/mcp-server/.env.example index 0142166..6ddd573 100644 --- a/mcp-server/.env.example +++ b/mcp-server/.env.example @@ -1,3 +1,3 @@ # Backend API Configuration +API_KEY=your-api-key-here BACKEND_API_URL=http://localhost:8000 -API_KEY=dev-secret-key diff --git a/mcp-server/.gitignore b/mcp-server/.gitignore new file mode 100644 index 0000000..b455f67 --- /dev/null +++ b/mcp-server/.gitignore @@ -0,0 +1,11 @@ +# Python +__pycache__/ +*.pyc +*.pyo + +# Virtual environment +venv/ + +# Environment (secrets) +.env +.env.local diff --git a/mcp-server/api_client.py b/mcp-server/api_client.py new file mode 100644 index 0000000..a2fcfc3 --- /dev/null +++ b/mcp-server/api_client.py @@ -0,0 +1,68 @@ +"""Persistent HTTP client for backend API communication. + +Uses a module-level client to avoid creating new TCP connections per tool call. +The client is initialized lazily on first use and reused for all subsequent calls. +Concurrent access is serialized via asyncio.Lock to prevent duplicate clients. +""" +import asyncio +from typing import Any, Optional + +import httpx + +from config import BACKEND_API_URL, API_KEY + + +# Persistent client reused across all tool calls +_client: Optional[httpx.AsyncClient] = None +_client_lock: asyncio.Lock = asyncio.Lock() + + +def _get_headers() -> dict[str, str]: + """Return Authorization header with the configured API_KEY. + + Raises ValueError if API_KEY is empty or unset. + """ + if not API_KEY: + raise ValueError( + "No API_KEY configured. Set API_KEY in .env or environment." + ) + return {"Authorization": f"Bearer {API_KEY}"} + + +async def get_client() -> httpx.AsyncClient: + """Get or create the persistent HTTP client.""" + global _client + async with _client_lock: + if _client is None or _client.is_closed: + _client = httpx.AsyncClient( + base_url=BACKEND_API_URL, + timeout=120.0, + headers=_get_headers(), + ) + return _client + + +async def api_get(path: str, **kwargs: Any) -> dict: + """Make a GET request to the backend API.""" + client = await get_client() + response = await client.get(path, **kwargs) + response.raise_for_status() + return response.json() + + +async def api_post(path: str, json: dict, **kwargs: Any) -> dict: + """Make a POST request to the backend API.""" + client = await get_client() + response = await client.post(path, json=json, **kwargs) + response.raise_for_status() + return response.json() + + +async def close_client() -> None: + """Close the persistent client. Call on server shutdown.""" + global _client + async with _client_lock: + local = _client + _client = None + if local and not local.is_closed: + await local.aclose() diff --git a/mcp-server/config.py b/mcp-server/config.py index 4885fbb..9cafd13 100644 --- a/mcp-server/config.py +++ b/mcp-server/config.py @@ -1,19 +1,16 @@ -""" -API Configuration - Single Source of Truth for API Versioning +"""MCP server configuration from environment variables.""" +import os -Change API_VERSION here to update all API calls across the MCP server. -Example: "v1" -> "v2" will change /api/v1/* to /api/v2/* -""" +from dotenv import load_dotenv -# ============================================================================= -# API VERSION CONFIGURATION -# ============================================================================= +load_dotenv() API_VERSION = "v1" +API_PREFIX = f"/api/{API_VERSION}" -# ============================================================================= -# DERIVED PREFIXES (auto-calculated from version) -# ============================================================================= +BACKEND_BASE_URL = os.getenv("BACKEND_API_URL", "http://localhost:8000") +BACKEND_API_URL = f"{BACKEND_BASE_URL}{API_PREFIX}" +API_KEY = os.getenv("API_KEY", "") -# Current versioned API prefix: /api/v1 -API_PREFIX = f"/api/{API_VERSION}" +SERVER_NAME = "codeintel-mcp" +SERVER_VERSION = "0.4.0" diff --git a/mcp-server/formatters.py b/mcp-server/formatters.py new file mode 100644 index 0000000..8808629 --- /dev/null +++ b/mcp-server/formatters.py @@ -0,0 +1,196 @@ +"""Response formatters that convert API responses to markdown. + +Each formatter is a pure function: takes API response dict, returns markdown string. +This makes them independently testable without any HTTP calls. +""" + + +def format_search_results(result: dict) -> str: + """Format semantic search results as markdown. + + Supports both v1 (count/results) and v2 (total/results) response shapes + so the formatter stays resilient across API versions. + """ + total = result.get("total") or result.get("count", 0) + cached = " (cached)" if result.get("cached") else "" + version = result.get("search_version", "v1") + output = f"# Code Search Results ({version})\n\nFound {total} results{cached}\n\n" + + if not result.get("results"): + return output + "No results found.\n" + + for idx, res in enumerate(result["results"], 1): + score_raw = res.get("score") + try: + score = float(score_raw) * 100 + except (TypeError, ValueError): + score = 0 + name = res.get("name", "unknown") + file_path = res.get("file_path", "unknown") + lang = res.get("language", "unknown") + line_start = res.get("line_start", 0) + line_end = res.get("line_end", 0) + code = res.get("code", "") + + output += f"## {idx}. {name} ({score:.0f}% match)\n" + output += f"**File:** `{file_path}`\n" + + # v2 adds qualified_name and signature + qualified = res.get("qualified_name") + if qualified and qualified != name: + output += f"**Qualified:** `{qualified}`\n" + signature = res.get("signature") + if signature: + output += f"**Signature:** `{signature}`\n" + + output += f"**Language:** {lang} | **Lines:** {line_start}-{line_end}\n" + + reason = res.get("match_reason") + if reason: + output += f"**Why:** {reason}\n" + + output += f"\n```{lang}\n{code}\n```\n\n" + + return output + + +def format_repositories(result: dict) -> str: + """Format repository listing as markdown.""" + output = "# Indexed Repositories\n\n" + + if not result.get("repositories"): + return output + "No repositories indexed yet.\n" + + for repo in result["repositories"]: + output += f"### {repo.get('name', 'unknown')}\n" + output += f"- **ID:** `{repo.get('id')}`\n" + output += f"- **Status:** {repo.get('status', 'unknown')}\n" + output += f"- **Functions:** {repo.get('file_count', 0):,}\n" + output += f"- **Branch:** {repo.get('branch', 'main')}\n\n" + + return output + + +def format_dependency_graph(result: dict) -> str: + """Format dependency graph analysis as markdown.""" + nodes = result.get("nodes", []) + edges = result.get("edges", []) + metrics = result.get("metrics", {}) + + output = "# Dependency Graph Analysis\n\n" + output += f"**Total Files:** {len(nodes)}\n" + output += f"**Total Dependencies:** {metrics.get('total_edges', len(edges))}\n" + output += f"**Avg Dependencies per File:** {metrics.get('avg_dependencies', 0):.1f}\n\n" + + # Most-imported files (highest number of dependents) + dependent_count: dict[str, int] = {} + for edge in edges: + target = edge.get("target", "") + dependent_count[target] = dependent_count.get(target, 0) + 1 + + if dependent_count: + sorted_deps = sorted( + dependent_count.items(), key=lambda x: x[1], reverse=True + )[:5] + output += "## Most Critical Files (High Impact)\n\n" + for file, count in sorted_deps: + output += f"- `{file}` - **{count} dependents**\n" + output += "\n" + + high_import = [n for n in nodes if n.get("imports", 0) >= 3] + if high_import: + output += "## Files with Most Imports\n\n" + for f in sorted(high_import, key=lambda x: x.get("imports", 0), reverse=True)[:5]: + output += f"- `{f.get('id', '')}` - imports {f.get('imports', 0)} files\n" + + return output + + +def format_code_style(result: dict) -> str: + """Format code style analysis as markdown.""" + summary = result.get("summary", {}) + output = "# Code Style Analysis\n\n" + output += f"**Files Analyzed:** {summary.get('total_files_analyzed', 0)}\n" + output += f"**Functions:** {summary.get('total_functions', 0)}\n" + output += f"**Async Adoption:** {summary.get('async_adoption', '0%')}\n" + output += f"**Type Hints:** {summary.get('type_hints_usage', '0%')}\n\n" + + naming = result.get("naming_conventions", {}).get("functions") + if naming: + output += "## Function Naming Conventions\n\n" + for conv, info in naming.items(): + output += f"- **{conv}:** {info.get('percentage', '?')} ({info.get('count', 0)} functions)\n" + output += "\n" + + top_imports = result.get("top_imports") + if top_imports: + output += "## Most Common Imports\n\n" + for item in top_imports[:10]: + output += f"- `{item.get('module', '')}` (used {item.get('count', 0)}x)\n" + + return output + + +def format_impact_analysis(result: dict) -> str: + """Format file impact analysis as markdown.""" + output = f"# Impact Analysis: {result.get('file', 'unknown')}\n\n" + output += f"**Risk Level:** {result.get('risk_level', 'unknown').upper()}\n" + output += f"**Impact Summary:** {result.get('impact_summary', '')}\n\n" + + deps = result.get("direct_dependencies", []) + output += f"## Dependencies ({len(deps)})\n" + output += "Files this file imports:\n" + for dep in deps[:10]: + output += f"- `{dep}`\n" + output += "\n" + + dependents = result.get("all_dependents", []) + output += f"## Dependents ({len(dependents)})\n" + output += "Files that would be affected by changes:\n" + for dep in dependents[:15]: + output += f"- `{dep}`\n" + + test_files = result.get("test_files") + if test_files: + output += "\n## Related Tests\n" + for test in test_files: + output += f"- `{test}`\n" + + return output + + +def format_repository_insights(result: dict) -> str: + """Format repository insights as markdown.""" + output = f"# Repository Insights: {result.get('name', 'unknown')}\n\n" + output += f"**Status:** {result.get('status', 'unknown')}\n" + output += f"**Functions Indexed:** {result.get('functions_indexed', 0):,}\n" + output += f"**Total Files:** {result.get('total_files', 0)}\n" + output += f"**Total Dependencies:** {result.get('total_dependencies', 0)}\n\n" + + metrics = result.get("graph_metrics", {}) + critical = metrics.get("most_critical_files") + if critical: + output += "## Most Critical Files\n" + for item in critical[:5]: + output += f"- `{item.get('file', '')}` ({item.get('dependents', 0)} dependents)\n" + + return output + + +def format_codebase_dna(result: dict) -> str: + """Format codebase DNA extraction as markdown.""" + dna_markdown = result.get("dna", "") + cached = " (cached)" if result.get("cached") else "" + + output = f"# Codebase DNA{cached}\n\n" + output += "**Use this information to write code that matches existing patterns.**\n\n" + output += dna_markdown + output += "\n---\n" + output += "**Instructions:** When generating code for this codebase:\n" + output += "1. Follow the auth patterns shown above\n" + output += "2. Use the service layer structure (singletons in dependencies.py)\n" + output += "3. Match the database conventions (ID types, timestamps, RLS)\n" + output += "4. Use the logging patterns shown\n" + output += "5. Follow the naming conventions\n" + + return output diff --git a/mcp-server/handlers.py b/mcp-server/handlers.py new file mode 100644 index 0000000..5154bb8 --- /dev/null +++ b/mcp-server/handlers.py @@ -0,0 +1,126 @@ +"""Tool handler dispatch. + +Maps tool names to their API calls and response formatters. +Each handler follows the same pattern: call API, format response. +Error handling is centralized in call_tool() so individual handlers stay clean. +""" +import logging +from typing import Any + +import httpx +import mcp.types as types + +logger = logging.getLogger(__name__) + +from api_client import api_get, api_post +from formatters import ( + format_codebase_dna, + format_code_style, + format_dependency_graph, + format_impact_analysis, + format_repositories, + format_repository_insights, + format_search_results, +) + + +def _clamp_max_results(raw: Any) -> int: + """Validate and clamp max_results to [1, 100].""" + try: + value = int(raw) + except (TypeError, ValueError): + return 10 + return max(1, min(value, 100)) + + +async def _handle_search(args: dict[str, Any]) -> str: + # Map tool schema's max_results to v2 API's top_k + top_k = _clamp_max_results(args.get("max_results", 10)) + payload = { + "query": args["query"], + "repo_id": args["repo_id"], + "top_k": top_k, + "use_reranking": True, + } + result = await api_post("/search/v2", json=payload) + return format_search_results(result) + + +async def _handle_list_repositories(args: dict[str, Any]) -> str: + result = await api_get("/repos") + return format_repositories(result) + + +async def _handle_dependency_graph(args: dict[str, Any]) -> str: + result = await api_get(f"/repos/{args['repo_id']}/dependencies") + return format_dependency_graph(result) + + +async def _handle_code_style(args: dict[str, Any]) -> str: + result = await api_get(f"/repos/{args['repo_id']}/style-analysis") + return format_code_style(result) + + +async def _handle_impact(args: dict[str, Any]) -> str: + result = await api_post( + f"/repos/{args['repo_id']}/impact", + json={"repo_id": args["repo_id"], "file_path": args["file_path"]}, + ) + return format_impact_analysis(result) + + +async def _handle_insights(args: dict[str, Any]) -> str: + result = await api_get(f"/repos/{args['repo_id']}/insights") + return format_repository_insights(result) + + +async def _handle_dna(args: dict[str, Any]) -> str: + result = await api_get(f"/repos/{args['repo_id']}/dna?format=markdown") + return format_codebase_dna(result) + + +# Tool name -> handler mapping +_HANDLERS: dict[str, Any] = { + "search_code": _handle_search, + "list_repositories": _handle_list_repositories, + "get_dependency_graph": _handle_dependency_graph, + "analyze_code_style": _handle_code_style, + "analyze_impact": _handle_impact, + "get_repository_insights": _handle_insights, + "get_codebase_dna": _handle_dna, +} + + +def _safe_error_message(tool_name: str, args: dict[str, Any], error: Exception) -> str: + """Build error message with context but without leaking internal details.""" + repo_id = args.get("repo_id", "unknown") + if isinstance(error, httpx.HTTPStatusError): + status = error.response.status_code + return f"Backend returned {status} for tool '{tool_name}' (repo: {repo_id})" + if isinstance(error, httpx.TimeoutException): + return f"Request timed out for tool '{tool_name}' (repo: {repo_id})" + if isinstance(error, httpx.ConnectError): + return f"Cannot connect to backend for tool '{tool_name}'. Is the server running?" + if isinstance(error, ValueError): + logger.warning("ValueError in tool '%s' (repo: %s): %s", tool_name, repo_id, error) + return f"Tool input error for '{tool_name}' (repo: {repo_id})" + logger.exception("Unexpected error in tool '%s' (repo: %s)", tool_name, repo_id) + return f"Unexpected error in tool '{tool_name}' (repo: {repo_id})" + + +async def call_tool( + name: str, arguments: dict[str, Any] | None +) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + """Dispatch a tool call to the appropriate handler.""" + args = arguments or {} + + handler = _HANDLERS.get(name) + if handler is None: + return [types.TextContent(type="text", text=f"Unknown tool: {name}")] + + try: + text = await handler(args) + return [types.TextContent(type="text", text=text)] + except Exception as e: + msg = _safe_error_message(name, args, e) + return [types.TextContent(type="text", text=msg)] diff --git a/mcp-server/pyproject.toml b/mcp-server/pyproject.toml new file mode 100644 index 0000000..68f242a --- /dev/null +++ b/mcp-server/pyproject.toml @@ -0,0 +1,3 @@ +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/mcp-server/requirements.txt b/mcp-server/requirements.txt index 0c921a2..6c415b1 100644 --- a/mcp-server/requirements.txt +++ b/mcp-server/requirements.txt @@ -3,3 +3,7 @@ mcp>=1.0.0 httpx>=0.27.0 python-dotenv>=1.0.0 pydantic>=2.0.0 + +# Test Dependencies +pytest>=8.0.0 +pytest-asyncio>=0.23.0 diff --git a/mcp-server/server.py b/mcp-server/server.py index 9d74005..184e200 100644 --- a/mcp-server/server.py +++ b/mcp-server/server.py @@ -1,388 +1,53 @@ #!/usr/bin/env python3 -""" -CodeIntel MCP Server -Provides codebase intelligence tools for LLMs via Model Context Protocol +"""CodeIntel MCP Server entry point. + +Provides codebase intelligence tools for LLMs via Model Context Protocol. +All tool definitions, handlers, and formatters are in their own modules. """ import asyncio -import os -from typing import Any -import httpx from mcp.server import Server -from mcp.server.models import InitializationOptions +from mcp.server.models import InitializationOptions, ServerCapabilities import mcp.server.stdio import mcp.types as types -from dotenv import load_dotenv - -# Import API config (single source of truth for versioning) -from config import API_PREFIX -# Load environment variables -load_dotenv() +from config import SERVER_NAME, SERVER_VERSION +from tools import get_tool_schemas +from handlers import call_tool +from api_client import close_client -# Configuration -BACKEND_BASE_URL = os.getenv("BACKEND_API_URL", "http://localhost:8000") -BACKEND_API_URL = f"{BACKEND_BASE_URL}{API_PREFIX}" # Full versioned URL -API_KEY = os.getenv("API_KEY", "dev-secret-key") - -# Create MCP server instance -server = Server("codeintel-mcp") +server = Server(SERVER_NAME) @server.list_tools() async def handle_list_tools() -> list[types.Tool]: - """List available tools for codebase intelligence""" - return [ - types.Tool( - name="search_code", - description="Semantically search code in a repository. Finds code by meaning, not just keywords. Use this to find existing implementations, patterns, or specific functionality.", - inputSchema={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query (natural language or code snippet). Examples: 'authentication middleware', 'React hook for state', 'database connection pool'" - }, - "repo_id": { - "type": "string", - "description": "Repository identifier" - }, - "max_results": { - "type": "integer", - "description": "Maximum number of results to return (default: 10)", - "default": 10 - } - }, - "required": ["query", "repo_id"] - } - ), - types.Tool( - name="list_repositories", - description="List all indexed repositories available for analysis", - inputSchema={ - "type": "object", - "properties": {} - } - ), - types.Tool( - name="get_dependency_graph", - description="Get the complete dependency graph for a repository. Shows which files depend on which, identifies critical files, and reveals architecture patterns.", - inputSchema={ - "type": "object", - "properties": { - "repo_id": { - "type": "string", - "description": "Repository identifier" - } - }, - "required": ["repo_id"] - } - ), - types.Tool( - name="analyze_code_style", - description="Analyze team coding patterns and conventions. Returns naming conventions (snake_case vs camelCase), async usage, type hint usage, common imports, and coding patterns. Use this to match team style when generating code.", - inputSchema={ - "type": "object", - "properties": { - "repo_id": { - "type": "string", - "description": "Repository identifier" - } - }, - "required": ["repo_id"] - } - ), - types.Tool( - name="analyze_impact", - description="Analyze the impact of changing a specific file. Shows what files depend on it, what it depends on, risk level, and related test files. Critical for understanding change consequences.", - inputSchema={ - "type": "object", - "properties": { - "repo_id": { - "type": "string", - "description": "Repository identifier" - }, - "file_path": { - "type": "string", - "description": "Path to the file to analyze (relative to repo root)" - } - }, - "required": ["repo_id", "file_path"] - } - ), - types.Tool( - name="get_repository_insights", - description="Get comprehensive insights about a repository including dependency metrics, code style summary, and architecture overview. Use this for high-level codebase understanding.", - inputSchema={ - "type": "object", - "properties": { - "repo_id": { - "type": "string", - "description": "Repository identifier" - } - }, - "required": ["repo_id"] - } - ), - types.Tool( - name="get_codebase_dna", - description="Extract the architectural DNA of a codebase. Returns patterns, conventions, and constraints that define how code should be written. Use this BEFORE generating any code to understand: authentication patterns, service layer structure, database conventions (UUID vs SERIAL, RLS policies), error handling, logging patterns, naming conventions, and common imports. This ensures generated code matches existing architecture.", - inputSchema={ - "type": "object", - "properties": { - "repo_id": { - "type": "string", - "description": "Repository identifier" - } - }, - "required": ["repo_id"] - } - ) - ] + """Return all available tool schemas.""" + return get_tool_schemas() @server.call_tool() async def handle_call_tool( - name: str, arguments: dict[str, Any] | None + name: str, arguments: dict | None ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - """Handle tool execution""" - - if arguments is None: - arguments = {} - - try: - async with httpx.AsyncClient(timeout=120.0) as client: - headers = {"Authorization": f"Bearer {API_KEY}"} - - if name == "search_code": - response = await client.post( - f"{BACKEND_API_URL}/search", - json=arguments, - headers=headers - ) - response.raise_for_status() - result = response.json() - - # Format results - formatted = f"# Code Search Results\n\n" - formatted += f"Found {result.get('count', 0)} results" - if result.get('cached'): - formatted += " (⚡ cached)\n\n" - else: - formatted += "\n\n" - - if result.get("results"): - for idx, res in enumerate(result["results"], 1): - formatted += f"## {idx}. {res.get('name', 'unknown')} ({res.get('score', 0)*100:.0f}% match)\n" - formatted += f"**File:** `{res.get('file_path', 'unknown')}`\n" - formatted += f"**Type:** {res.get('type', 'unknown')} | **Language:** {res.get('language', 'unknown')}\n" - formatted += f"**Lines:** {res.get('line_start', 0)}-{res.get('line_end', 0)}\n\n" - formatted += f"```{res.get('language', 'python')}\n{res.get('code', '')}\n```\n\n" - else: - formatted += "No results found.\n" - - return [types.TextContent(type="text", text=formatted)] - - elif name == "list_repositories": - response = await client.get( - f"{BACKEND_API_URL}/repos", - headers=headers - ) - response.raise_for_status() - result = response.json() - - repo_list = "# Indexed Repositories\n\n" - if result.get("repositories"): - for repo in result["repositories"]: - repo_list += f"### {repo.get('name', 'unknown')}\n" - repo_list += f"- **ID:** `{repo.get('id')}`\n" - repo_list += f"- **Status:** {repo.get('status', 'unknown')}\n" - repo_list += f"- **Functions:** {repo.get('file_count', 0):,}\n" - repo_list += f"- **Branch:** {repo.get('branch', 'main')}\n\n" - else: - repo_list += "No repositories indexed yet.\n" - - return [types.TextContent(type="text", text=repo_list)] - - elif name == "get_dependency_graph": - response = await client.get( - f"{BACKEND_API_URL}/repos/{arguments['repo_id']}/dependencies", - headers=headers - ) - response.raise_for_status() - result = response.json() - - nodes = result.get('nodes', []) - edges = result.get('edges', []) - metrics = result.get('metrics', {}) - - formatted = f"# Dependency Graph Analysis\n\n" - formatted += f"**Total Files:** {len(nodes)}\n" - formatted += f"**Total Dependencies:** {metrics.get('total_edges', len(edges))}\n" - formatted += f"**Avg Dependencies per File:** {metrics.get('avg_dependencies', 0):.1f}\n\n" - - # Find most imported files (most dependents) - dependent_count = {} - for edge in edges: - target = edge.get('target', '') - dependent_count[target] = dependent_count.get(target, 0) + 1 - - if dependent_count: - sorted_deps = sorted(dependent_count.items(), key=lambda x: x[1], reverse=True)[:5] - formatted += "## Most Critical Files (High Impact)\n\n" - for file, count in sorted_deps: - formatted += f"- `{file}` - **{count} dependents**\n" - formatted += "\n" - - # Show high-import files - high_import_files = [n for n in nodes if n.get('imports', 0) >= 3] - if high_import_files: - formatted += "## Files with Most Imports\n\n" - for f in sorted(high_import_files, key=lambda x: x.get('imports', 0), reverse=True)[:5]: - formatted += f"- `{f['id']}` - imports {f['imports']} files\n" - - return [types.TextContent(type="text", text=formatted)] - - elif name == "analyze_code_style": - response = await client.get( - f"{BACKEND_API_URL}/repos/{arguments['repo_id']}/style-analysis", - headers=headers - ) - response.raise_for_status() - result = response.json() - - formatted = f"# Code Style Analysis\n\n" - - summary = result.get('summary', {}) - formatted += f"**Files Analyzed:** {summary.get('total_files_analyzed', 0)}\n" - formatted += f"**Functions:** {summary.get('total_functions', 0)}\n" - formatted += f"**Async Adoption:** {summary.get('async_adoption', '0%')}\n" - formatted += f"**Type Hints:** {summary.get('type_hints_usage', '0%')}\n\n" - - # Naming conventions - if result.get('naming_conventions', {}).get('functions'): - formatted += "## Function Naming Conventions\n\n" - for conv, info in result['naming_conventions']['functions'].items(): - formatted += f"- **{conv}:** {info['percentage']} ({info['count']} functions)\n" - formatted += "\n" - - # Top imports - if result.get('top_imports'): - formatted += "## Most Common Imports\n\n" - for item in result['top_imports'][:10]: - formatted += f"- `{item['module']}` (used {item['count']}×)\n" - - return [types.TextContent(type="text", text=formatted)] - - elif name == "analyze_impact": - response = await client.post( - f"{BACKEND_API_URL}/repos/{arguments['repo_id']}/impact", - json={"repo_id": arguments['repo_id'], "file_path": arguments['file_path']}, - headers=headers - ) - response.raise_for_status() - result = response.json() - - formatted = f"# Impact Analysis: {result.get('file', 'unknown')}\n\n" - formatted += f"**Risk Level:** {result.get('risk_level', 'unknown').upper()}\n" - formatted += f"**Impact Summary:** {result.get('impact_summary', '')}\n\n" - - formatted += f"## Dependencies ({len(result.get('direct_dependencies', []))})\n" - formatted += "Files this file imports:\n" - for dep in result.get('direct_dependencies', [])[:10]: - formatted += f"- `{dep}`\n" - formatted += "\n" - - formatted += f"## Dependents ({len(result.get('all_dependents', []))})\n" - formatted += "Files that would be affected by changes:\n" - for dep in result.get('all_dependents', [])[:15]: - formatted += f"- `{dep}`\n" - - if result.get('test_files'): - formatted += f"\n## Related Tests\n" - for test in result['test_files']: - formatted += f"- `{test}`\n" - - return [types.TextContent(type="text", text=formatted)] - - elif name == "get_repository_insights": - response = await client.get( - f"{BACKEND_API_URL}/repos/{arguments['repo_id']}/insights", - headers=headers - ) - response.raise_for_status() - result = response.json() - - formatted = f"# Repository Insights: {result.get('name', 'unknown')}\n\n" - formatted += f"**Status:** {result.get('status', 'unknown')}\n" - formatted += f"**Functions Indexed:** {result.get('functions_indexed', 0):,}\n" - formatted += f"**Total Files:** {result.get('total_files', 0)}\n" - formatted += f"**Total Dependencies:** {result.get('total_dependencies', 0)}\n\n" - - metrics = result.get('graph_metrics', {}) - if metrics.get('most_critical_files'): - formatted += "## Most Critical Files\n" - for item in metrics['most_critical_files'][:5]: - formatted += f"- `{item['file']}` ({item['dependents']} dependents)\n" - - return [types.TextContent(type="text", text=formatted)] - - elif name == "get_codebase_dna": - response = await client.get( - f"{BACKEND_API_URL}/repos/{arguments['repo_id']}/dna?format=markdown", - headers=headers - ) - response.raise_for_status() - result = response.json() - - # DNA is already formatted as markdown by the backend - dna_markdown = result.get('dna', '') - - formatted = "# Codebase DNA\n\n" - formatted += "**Use this information to write code that matches the existing patterns.**\n\n" - - if result.get('cached'): - formatted += "_(⚡ cached)_\n\n" - - formatted += dna_markdown - - formatted += "\n---\n" - formatted += "**Instructions:** When generating code for this codebase:\n" - formatted += "1. Follow the auth patterns shown above\n" - formatted += "2. Use the service layer structure (singletons in dependencies.py)\n" - formatted += "3. Match the database conventions (ID types, timestamps, RLS)\n" - formatted += "4. Use the logging patterns shown\n" - formatted += "5. Follow the naming conventions\n" - - return [types.TextContent(type="text", text=formatted)] - - else: - raise ValueError(f"Unknown tool: {name}") - - except httpx.HTTPError as e: - error_msg = f"API Error: {str(e)}" - return [types.TextContent(type="text", text=error_msg)] - except Exception as e: - error_msg = f"Error executing tool: {str(e)}" - return [types.TextContent(type="text", text=error_msg)] + """Dispatch tool calls to the handler layer.""" + return await call_tool(name, arguments) -async def main(): - """Run the MCP server""" - from mcp.server.models import ServerCapabilities - - async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): - await server.run( - read_stream, - write_stream, - InitializationOptions( - server_name="codeintel-mcp", - server_version="0.3.0", - capabilities=ServerCapabilities( - tools={} +async def main() -> None: + """Run the MCP server over stdio transport.""" + try: + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name=SERVER_NAME, + server_version=SERVER_VERSION, + capabilities=ServerCapabilities(tools={}), ), - ), - ) + ) + finally: + await close_client() if __name__ == "__main__": diff --git a/mcp-server/tests/__init__.py b/mcp-server/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcp-server/tests/conftest.py b/mcp-server/tests/conftest.py new file mode 100644 index 0000000..360d7b8 --- /dev/null +++ b/mcp-server/tests/conftest.py @@ -0,0 +1,8 @@ +"""Shared test configuration. + +Adds the mcp-server root to sys.path so tests can import modules directly. +""" +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/mcp-server/tests/test_config.py b/mcp-server/tests/test_config.py new file mode 100644 index 0000000..a6176a6 --- /dev/null +++ b/mcp-server/tests/test_config.py @@ -0,0 +1,23 @@ +"""Tests for MCP server configuration.""" +import pytest + +from config import API_PREFIX, SERVER_NAME, SERVER_VERSION, BACKEND_API_URL + + +class TestConfig: + def test_api_prefix_format(self): + """API prefix must match /api/v{n} pattern.""" + assert API_PREFIX.startswith("/api/v") + + def test_server_name(self): + assert SERVER_NAME == "codeintel-mcp" + + def test_server_version_semver(self): + """Version should be semver-like (x.y.z).""" + parts = SERVER_VERSION.split(".") + assert len(parts) == 3 + assert all(p.isdigit() for p in parts) + + def test_backend_url_has_prefix(self): + """BACKEND_API_URL should include the API prefix for direct use.""" + assert BACKEND_API_URL.endswith(API_PREFIX) diff --git a/mcp-server/tests/test_formatters.py b/mcp-server/tests/test_formatters.py new file mode 100644 index 0000000..8e3ff70 --- /dev/null +++ b/mcp-server/tests/test_formatters.py @@ -0,0 +1,199 @@ +"""Tests for response formatters. + +Formatters are pure functions (dict -> str) so they're straightforward to test +without any mocking or network calls. +""" +import pytest + +from formatters import ( + format_search_results, + format_repositories, + format_dependency_graph, + format_code_style, + format_impact_analysis, + format_repository_insights, + format_codebase_dna, +) + + +# -- Search results (v2 format) -- + +class TestFormatSearchResults: + def test_empty_results(self): + result = {"total": 0, "results": [], "cached": False, "search_version": "v2"} + output = format_search_results(result) + assert "Found 0 results" in output + assert "No results found" in output + + def test_basic_result(self): + result = { + "total": 1, + "cached": False, + "search_version": "v2", + "results": [{ + "name": "authenticate", + "qualified_name": "auth.service.authenticate", + "file_path": "backend/auth.py", + "code": "def authenticate(token): ...", + "signature": "def authenticate(token: str) -> User", + "language": "python", + "score": 0.95, + "line_start": 10, + "line_end": 20, + "match_reason": "Semantic match on authentication logic", + }], + } + output = format_search_results(result) + assert "authenticate" in output + assert "95% match" in output + assert "backend/auth.py" in output + assert "auth.service.authenticate" in output + assert "Signature" in output + assert "Why:" in output + assert "(v2)" in output + + def test_cached_flag(self): + result = {"total": 0, "results": [], "cached": True, "search_version": "v2"} + output = format_search_results(result) + assert "(cached)" in output + + def test_v1_fallback(self): + """Formatter handles v1-style response with 'count' field.""" + result = {"count": 0, "results": []} + output = format_search_results(result) + assert "Found 0 results" in output + assert "(v1)" in output + + def test_none_score_handled(self): + """score=None should not crash the formatter.""" + result = {"total": 1, "cached": False, "search_version": "v2", "results": [{ + "name": "test", "file_path": "t.py", "code": "pass", + "language": "python", "score": None, "line_start": 1, "line_end": 1, + }]} + output = format_search_results(result) + assert "0% match" in output + + def test_no_emoji_in_output(self): + """CLAUDE.md violation check: no emojis anywhere in formatted output.""" + result = {"total": 1, "cached": True, "search_version": "v2", "results": [{ + "name": "test", "file_path": "test.py", "code": "pass", + "language": "python", "score": 0.5, "line_start": 1, "line_end": 1, + }]} + output = format_search_results(result) + # Lightning bolt was the specific emoji found in OPE-91 audit + assert "\u26a1" not in output + + +# -- Repositories -- + +class TestFormatRepositories: + def test_no_repos(self): + output = format_repositories({"repositories": []}) + assert "No repositories indexed" in output + + def test_repo_listing(self): + output = format_repositories({ + "repositories": [{ + "id": "abc-123", + "name": "my-project", + "status": "indexed", + "file_count": 1500, + "branch": "main", + }] + }) + assert "my-project" in output + assert "abc-123" in output + assert "indexed" in output + assert "1,500" in output + + +# -- Dependency graph -- + +class TestFormatDependencyGraph: + def test_empty_graph(self): + output = format_dependency_graph({"nodes": [], "edges": [], "metrics": {}}) + assert "Total Files:** 0" in output + + def test_critical_files_ranked(self): + output = format_dependency_graph({ + "nodes": [{"id": "a.py"}, {"id": "b.py"}], + "edges": [ + {"source": "b.py", "target": "a.py"}, + {"source": "c.py", "target": "a.py"}, + ], + "metrics": {"total_edges": 2, "avg_dependencies": 1.0}, + }) + assert "a.py" in output + assert "2 dependents" in output + + +# -- Code style -- + +class TestFormatCodeStyle: + def test_basic_summary(self): + output = format_code_style({ + "summary": { + "total_files_analyzed": 50, + "total_functions": 200, + "async_adoption": "35%", + "type_hints_usage": "80%", + }, + }) + assert "50" in output + assert "200" in output + assert "35%" in output + assert "80%" in output + + +# -- Impact analysis -- + +class TestFormatImpactAnalysis: + def test_high_risk(self): + output = format_impact_analysis({ + "file": "core/engine.py", + "risk_level": "high", + "impact_summary": "Central dependency", + "direct_dependencies": ["utils.py"], + "all_dependents": ["api.py", "cli.py"], + "test_files": ["test_engine.py"], + }) + assert "core/engine.py" in output + assert "HIGH" in output + assert "utils.py" in output + assert "test_engine.py" in output + + +# -- Repository insights -- + +class TestFormatRepositoryInsights: + def test_basic_insights(self): + output = format_repository_insights({ + "name": "opencodeintel", + "status": "indexed", + "functions_indexed": 500, + "total_files": 80, + "total_dependencies": 120, + }) + assert "opencodeintel" in output + assert "500" in output + + +# -- Codebase DNA -- + +class TestFormatCodebaseDna: + def test_dna_output(self): + output = format_codebase_dna({ + "dna": "## Patterns\n- Uses FastAPI\n- SQLAlchemy ORM", + "cached": False, + }) + assert "Codebase DNA" in output + assert "FastAPI" in output + assert "Follow the auth patterns" in output + + def test_dna_cached(self): + output = format_codebase_dna({"dna": "test", "cached": True}) + assert "(cached)" in output + + def test_no_emoji_in_dna(self): + output = format_codebase_dna({"dna": "test", "cached": True}) + assert "\u26a1" not in output diff --git a/mcp-server/tests/test_handlers.py b/mcp-server/tests/test_handlers.py new file mode 100644 index 0000000..63648ea --- /dev/null +++ b/mcp-server/tests/test_handlers.py @@ -0,0 +1,134 @@ +"""Tests for tool handler dispatch. + +Handlers call the API client, so we mock api_get/api_post to test +dispatch logic and error handling without network calls. +""" +import pytest +from unittest.mock import AsyncMock, patch +import httpx +import mcp.types as types + +from handlers import call_tool, _safe_error_message, _clamp_max_results + + +# -- Dispatch -- + +class TestCallTool: + @pytest.mark.asyncio + async def test_unknown_tool(self): + result = await call_tool("nonexistent_tool", {}) + assert len(result) == 1 + assert "Unknown tool" in result[0].text + + @pytest.mark.asyncio + @patch("handlers.api_post", new_callable=AsyncMock) + async def test_search_dispatches_to_v2(self, mock_post): + """Search handler should call /search/v2, not /search.""" + mock_post.return_value = { + "total": 0, "results": [], "cached": False, "search_version": "v2" + } + await call_tool("search_code", {"query": "test", "repo_id": "abc"}) + mock_post.assert_called_once() + call_path = mock_post.call_args[0][0] + assert call_path == "/search/v2" + + @pytest.mark.asyncio + @patch("handlers.api_post", new_callable=AsyncMock) + async def test_search_maps_max_results_to_top_k(self, mock_post): + """Tool schema uses max_results, v2 API expects top_k.""" + mock_post.return_value = { + "total": 0, "results": [], "cached": False, "search_version": "v2" + } + await call_tool("search_code", { + "query": "auth", "repo_id": "abc", "max_results": 5 + }) + payload = mock_post.call_args[1]["json"] + assert payload["top_k"] == 5 + assert "max_results" not in payload + + @pytest.mark.asyncio + @patch("handlers.api_get", new_callable=AsyncMock) + async def test_list_repos(self, mock_get): + mock_get.return_value = {"repositories": []} + result = await call_tool("list_repositories", {}) + assert len(result) == 1 + assert "No repositories indexed" in result[0].text + + @pytest.mark.asyncio + @patch("handlers.api_get", new_callable=AsyncMock) + async def test_dna_calls_correct_endpoint(self, mock_get): + mock_get.return_value = {"dna": "test patterns", "cached": False} + await call_tool("get_codebase_dna", {"repo_id": "r1"}) + call_path = mock_get.call_args[0][0] + assert "/repos/r1/dna" in call_path + + @pytest.mark.asyncio + @patch("handlers.api_get", new_callable=AsyncMock) + async def test_none_arguments_handled(self, mock_get): + """call_tool(name, None) should not crash.""" + mock_get.return_value = {"repositories": []} + result = await call_tool("list_repositories", None) + assert len(result) == 1 + assert "No repositories indexed" in result[0].text + + +# -- Input validation -- + +class TestClampMaxResults: + def test_default_on_none(self): + assert _clamp_max_results(None) == 10 + + def test_default_on_string(self): + assert _clamp_max_results("abc") == 10 + + def test_clamps_zero_to_one(self): + assert _clamp_max_results(0) == 1 + + def test_clamps_negative(self): + assert _clamp_max_results(-5) == 1 + + def test_clamps_over_max(self): + assert _clamp_max_results(500) == 100 + + def test_valid_value_passes(self): + assert _clamp_max_results(25) == 25 + + +# -- Error handling -- + +class TestSafeErrorMessage: + def test_http_status_error(self): + response = httpx.Response(403, request=httpx.Request("GET", "http://x")) + error = httpx.HTTPStatusError("forbidden", request=response.request, response=response) + msg = _safe_error_message("search_code", {"repo_id": "abc"}, error) + assert "403" in msg + assert "search_code" in msg + assert "abc" in msg + # Should NOT leak internal details like URLs or stack traces + assert "http://" not in msg + + def test_timeout_error(self): + error = httpx.TimeoutException("timed out") + msg = _safe_error_message("get_codebase_dna", {"repo_id": "r1"}, error) + assert "timed out" in msg.lower() + assert "get_codebase_dna" in msg + + def test_connect_error(self): + error = httpx.ConnectError("connection refused") + msg = _safe_error_message("search_code", {}, error) + assert "Cannot connect" in msg + + def test_value_error_sanitized(self): + """ValueError should not leak internal details.""" + error = ValueError("No API_KEY configured") + msg = _safe_error_message("search_code", {"repo_id": "r1"}, error) + assert "Tool input error" in msg + assert "search_code" in msg + # Internal message should NOT be in output + assert "No API_KEY" not in msg + + def test_generic_error_hides_details(self): + error = RuntimeError("internal traceback info") + msg = _safe_error_message("search_code", {"repo_id": "r1"}, error) + assert "internal traceback" not in msg + assert "Unexpected error" in msg diff --git a/mcp-server/tests/test_tools.py b/mcp-server/tests/test_tools.py new file mode 100644 index 0000000..3be62c1 --- /dev/null +++ b/mcp-server/tests/test_tools.py @@ -0,0 +1,72 @@ +"""Tests for tool schema definitions. + +Validates that tool schemas follow MCP protocol requirements +and that adding/removing tools doesn't break the schema contract. +""" +import pytest + +from tools import get_tool_schemas + +EXPECTED_TOOLS = { + "search_code", + "list_repositories", + "get_dependency_graph", + "analyze_code_style", + "analyze_impact", + "get_repository_insights", + "get_codebase_dna", +} + + +class TestToolSchemas: + def test_all_tools_registered(self): + schemas = get_tool_schemas() + names = {t.name for t in schemas} + assert names == EXPECTED_TOOLS + + def test_no_duplicate_names(self): + schemas = get_tool_schemas() + names = [t.name for t in schemas] + assert len(names) == len(set(names)) + + def test_all_have_descriptions(self): + """Every tool needs a description -- LLMs use it to decide when to call.""" + for tool in get_tool_schemas(): + assert tool.description, f"{tool.name} has no description" + assert len(tool.description) > 20, f"{tool.name} description too short" + + def test_all_have_valid_input_schema(self): + for tool in get_tool_schemas(): + schema = tool.inputSchema + assert schema.get("type") == "object", f"{tool.name} schema not object" + assert "properties" in schema, f"{tool.name} missing properties" + + def test_search_requires_query_and_repo(self): + schemas = {t.name: t for t in get_tool_schemas()} + search = schemas["search_code"] + assert "query" in search.inputSchema["required"] + assert "repo_id" in search.inputSchema["required"] + + def test_repo_tools_require_repo_id(self): + """Tools that operate on a repo should require repo_id.""" + schemas = {t.name: t for t in get_tool_schemas()} + repo_tools = [ + "get_dependency_graph", "analyze_code_style", + "analyze_impact", "get_repository_insights", "get_codebase_dna", + ] + for name in repo_tools: + required = schemas[name].inputSchema.get("required", []) + assert "repo_id" in required, f"{name} should require repo_id" + + def test_search_max_results_bounded(self): + """max_results schema should have min/max to prevent invalid searches.""" + schemas = {t.name: t for t in get_tool_schemas()} + max_results = schemas["search_code"].inputSchema["properties"]["max_results"] + assert max_results["type"] == "integer" + assert max_results["minimum"] >= 1 + assert max_results["maximum"] > max_results["minimum"] + + def test_list_repos_has_no_required_fields(self): + schemas = {t.name: t for t in get_tool_schemas()} + list_repos = schemas["list_repositories"] + assert "required" not in list_repos.inputSchema diff --git a/mcp-server/tools.py b/mcp-server/tools.py new file mode 100644 index 0000000..20091f1 --- /dev/null +++ b/mcp-server/tools.py @@ -0,0 +1,148 @@ +"""MCP tool schema definitions. + +Each tool has a name, description, and JSON Schema for its input. +Descriptions are optimized for LLM consumption -- they tell the model +WHEN and WHY to use each tool, not just what it does. +""" +import mcp.types as types + + +def get_tool_schemas() -> list[types.Tool]: + """Return all available tool definitions.""" + return [ + types.Tool( + name="search_code", + description=( + "Semantically search code in a repository. Finds code by meaning, " + "not just keywords. Use this to find existing implementations, " + "patterns, or specific functionality." + ), + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + "Search query (natural language or code snippet). " + "Examples: 'authentication middleware', " + "'React hook for state', 'database connection pool'" + ), + }, + "repo_id": { + "type": "string", + "description": "Repository identifier", + }, + "max_results": { + "type": "integer", + "description": "Maximum number of results (default: 10)", + "default": 10, + "minimum": 1, + "maximum": 100, + }, + }, + "required": ["query", "repo_id"], + }, + ), + types.Tool( + name="list_repositories", + description="List all indexed repositories available for analysis", + inputSchema={"type": "object", "properties": {}}, + ), + types.Tool( + name="get_dependency_graph", + description=( + "Get the complete dependency graph for a repository. Shows which " + "files depend on which, identifies critical files, and reveals " + "architecture patterns." + ), + inputSchema={ + "type": "object", + "properties": { + "repo_id": { + "type": "string", + "description": "Repository identifier", + } + }, + "required": ["repo_id"], + }, + ), + types.Tool( + name="analyze_code_style", + description=( + "Analyze team coding patterns and conventions. Returns naming " + "conventions, async usage, type hint usage, common imports, and " + "coding patterns. Use this to match team style when generating code." + ), + inputSchema={ + "type": "object", + "properties": { + "repo_id": { + "type": "string", + "description": "Repository identifier", + } + }, + "required": ["repo_id"], + }, + ), + types.Tool( + name="analyze_impact", + description=( + "Analyze the impact of changing a specific file. Shows what files " + "depend on it, what it depends on, risk level, and related test " + "files. Critical for understanding change consequences." + ), + inputSchema={ + "type": "object", + "properties": { + "repo_id": { + "type": "string", + "description": "Repository identifier", + }, + "file_path": { + "type": "string", + "description": "Path to the file to analyze (relative to repo root)", + }, + }, + "required": ["repo_id", "file_path"], + }, + ), + types.Tool( + name="get_repository_insights", + description=( + "Get comprehensive insights about a repository including dependency " + "metrics, code style summary, and architecture overview. Use this " + "for high-level codebase understanding." + ), + inputSchema={ + "type": "object", + "properties": { + "repo_id": { + "type": "string", + "description": "Repository identifier", + } + }, + "required": ["repo_id"], + }, + ), + types.Tool( + name="get_codebase_dna", + description=( + "Extract the architectural DNA of a codebase. Returns patterns, " + "conventions, and constraints that define how code should be written. " + "Use this BEFORE generating any code to understand: auth patterns, " + "service layer structure, database conventions, error handling, " + "logging patterns, naming conventions, and common imports. This " + "ensures generated code matches existing architecture." + ), + inputSchema={ + "type": "object", + "properties": { + "repo_id": { + "type": "string", + "description": "Repository identifier", + } + }, + "required": ["repo_id"], + }, + ), + ]