Skip to content

Commit dc70c70

Browse files
committed
fix: address PR review -- defensive formatting, input validation, error sanitization
- api_client: fix docstring (warns -> raises), add asyncio.Lock for safe lazy init
- formatters: guard against None/non-numeric score, use .get() for all dict access
- handlers: sanitize ValueError (no longer leaks internals), log errors server-side
- handlers: add _clamp_max_results to validate and bound top_k to [1, 100]
- tools: add minimum/maximum constraints to max_results schema
- tests: mock api_get in test_none_arguments_handled (no real HTTP)
- tests: add 8 new tests (score=None, clamp bounds, ValueError sanitized, schema bounds)
- tests: extract sys.path to conftest.py, remove from all test files
- .env.example: alphabetical key ordering (dotenv-linter compliance)

Skipped: N3 (env var rename -- deployment concern), N5 (emoji regex -- overkill)

45 tests pass.
1 parent dfe732d commit dc70c70

10 files changed

Lines changed: 102 additions & 40 deletions

File tree

mcp-server/.env.example

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
# Backend API Configuration
2-
BACKEND_API_URL=http://localhost:8000
32
API_KEY=your-api-key-here
3+
BACKEND_API_URL=http://localhost:8000

mcp-server/api_client.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,9 @@
22
33
Uses a module-level client to avoid creating new TCP connections per tool call.
44
The client is initialized lazily on first use and reused for all subsequent calls.
5+
Concurrent access is serialized via asyncio.Lock to prevent duplicate clients.
56
"""
7+
import asyncio
68
from typing import Any, Optional
79

810
import httpx
@@ -12,10 +14,14 @@
1214

1315
# Persistent client reused across all tool calls
1416
_client: Optional[httpx.AsyncClient] = None
17+
_client_lock: asyncio.Lock = asyncio.Lock()
1518

1619

1720
def _get_headers() -> dict[str, str]:
18-
"""Build auth headers. Warns if no API key is configured."""
21+
"""Return Authorization header with the configured API_KEY.
22+
23+
Raises ValueError if API_KEY is empty or unset.
24+
"""
1925
if not API_KEY:
2026
raise ValueError(
2127
"No API_KEY configured. Set API_KEY in .env or environment."
@@ -26,12 +32,13 @@ def _get_headers() -> dict[str, str]:
2632
async def get_client() -> httpx.AsyncClient:
2733
"""Get or create the persistent HTTP client."""
2834
global _client
29-
if _client is None or _client.is_closed:
30-
_client = httpx.AsyncClient(
31-
base_url=BACKEND_API_URL,
32-
timeout=120.0,
33-
headers=_get_headers(),
34-
)
35+
async with _client_lock:
36+
if _client is None or _client.is_closed:
37+
_client = httpx.AsyncClient(
38+
base_url=BACKEND_API_URL,
39+
timeout=120.0,
40+
headers=_get_headers(),
41+
)
3542
return _client
3643

3744

mcp-server/formatters.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,11 @@ def format_search_results(result: dict) -> str:
2020
return output + "No results found.\n"
2121

2222
for idx, res in enumerate(result["results"], 1):
23-
score = res.get("score", 0) * 100
23+
score_raw = res.get("score")
24+
try:
25+
score = float(score_raw) * 100
26+
except (TypeError, ValueError):
27+
score = 0
2428
name = res.get("name", "unknown")
2529
file_path = res.get("file_path", "unknown")
2630
lang = res.get("language", "unknown")
@@ -97,7 +101,7 @@ def format_dependency_graph(result: dict) -> str:
97101
if high_import:
98102
output += "## Files with Most Imports\n\n"
99103
for f in sorted(high_import, key=lambda x: x.get("imports", 0), reverse=True)[:5]:
100-
output += f"- `{f['id']}` - imports {f['imports']} files\n"
104+
output += f"- `{f.get('id', '<unknown>')}` - imports {f.get('imports', 0)} files\n"
101105

102106
return output
103107

@@ -115,14 +119,14 @@ def format_code_style(result: dict) -> str:
115119
if naming:
116120
output += "## Function Naming Conventions\n\n"
117121
for conv, info in naming.items():
118-
output += f"- **{conv}:** {info['percentage']} ({info['count']} functions)\n"
122+
output += f"- **{conv}:** {info.get('percentage', '?')} ({info.get('count', 0)} functions)\n"
119123
output += "\n"
120124

121125
top_imports = result.get("top_imports")
122126
if top_imports:
123127
output += "## Most Common Imports\n\n"
124128
for item in top_imports[:10]:
125-
output += f"- `{item['module']}` (used {item['count']}x)\n"
129+
output += f"- `{item.get('module', '<unknown>')}` (used {item.get('count', 0)}x)\n"
126130

127131
return output
128132

@@ -168,7 +172,7 @@ def format_repository_insights(result: dict) -> str:
168172
if critical:
169173
output += "## Most Critical Files\n"
170174
for item in critical[:5]:
171-
output += f"- `{item['file']}` ({item['dependents']} dependents)\n"
175+
output += f"- `{item.get('file', '<unknown>')}` ({item.get('dependents', 0)} dependents)\n"
172176

173177
return output
174178

mcp-server/handlers.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,14 @@
44
Each handler follows the same pattern: call API, format response.
55
Error handling is centralized in call_tool() so individual handlers stay clean.
66
"""
7+
import logging
78
from typing import Any
89

910
import httpx
1011
import mcp.types as types
1112

13+
logger = logging.getLogger(__name__)
14+
1215
from api_client import api_get, api_post
1316
from formatters import (
1417
format_codebase_dna,
@@ -21,12 +24,22 @@
2124
)
2225

2326

27+
def _clamp_max_results(raw: Any) -> int:
28+
"""Validate and clamp max_results to [1, 100]."""
29+
try:
30+
value = int(raw)
31+
except (TypeError, ValueError):
32+
return 10
33+
return max(1, min(value, 100))
34+
35+
2436
async def _handle_search(args: dict[str, Any]) -> str:
2537
# Map tool schema's max_results to v2 API's top_k
38+
top_k = _clamp_max_results(args.get("max_results", 10))
2639
payload = {
2740
"query": args["query"],
2841
"repo_id": args["repo_id"],
29-
"top_k": args.get("max_results", 10),
42+
"top_k": top_k,
3043
"use_reranking": True,
3144
}
3245
result = await api_post("/search/v2", json=payload)
@@ -89,7 +102,9 @@ def _safe_error_message(tool_name: str, args: dict[str, Any], error: Exception)
89102
if isinstance(error, httpx.ConnectError):
90103
return f"Cannot connect to backend for tool '{tool_name}'. Is the server running?"
91104
if isinstance(error, ValueError):
92-
return str(error)
105+
logger.warning("ValueError in tool '%s' (repo: %s): %s", tool_name, repo_id, error)
106+
return f"Tool input error for '{tool_name}' (repo: {repo_id})"
107+
logger.exception("Unexpected error in tool '%s' (repo: %s)", tool_name, repo_id)
93108
return f"Unexpected error in tool '{tool_name}' (repo: {repo_id})"
94109

95110

mcp-server/tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,8 @@
1+
"""Shared test configuration.
2+
3+
Adds the mcp-server root to sys.path so tests can import modules directly.
4+
"""
5+
import os
6+
import sys
7+
8+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

mcp-server/tests/test_config.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,5 @@
11
"""Tests for MCP server configuration."""
22
import pytest
3-
import sys
4-
import os
5-
6-
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
73

84
from config import API_PREFIX, SERVER_NAME, SERVER_VERSION, BACKEND_API_URL
95

mcp-server/tests/test_formatters.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,6 @@
44
without any mocking or network calls.
55
"""
66
import pytest
7-
import sys
8-
import os
9-
10-
# Add parent directory to path so we can import the modules
11-
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
127

138
from formatters import (
149
format_search_results,
@@ -69,6 +64,15 @@ def test_v1_fallback(self):
6964
assert "Found 0 results" in output
7065
assert "(v1)" in output
7166

67+
def test_none_score_handled(self):
68+
"""score=None should not crash the formatter."""
69+
result = {"total": 1, "cached": False, "search_version": "v2", "results": [{
70+
"name": "test", "file_path": "t.py", "code": "pass",
71+
"language": "python", "score": None, "line_start": 1, "line_end": 1,
72+
}]}
73+
output = format_search_results(result)
74+
assert "0% match" in output
75+
7276
def test_no_emoji_in_output(self):
7377
"""CLAUDE.md violation check: no emojis anywhere in formatted output."""
7478
result = {"total": 1, "cached": True, "search_version": "v2", "results": [{

mcp-server/tests/test_handlers.py

Lines changed: 34 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,16 +4,11 @@
44
dispatch logic and error handling without network calls.
55
"""
66
import pytest
7-
import sys
8-
import os
9-
10-
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11-
127
from unittest.mock import AsyncMock, patch
138
import httpx
149
import mcp.types as types
1510

16-
from handlers import call_tool, _safe_error_message
11+
from handlers import call_tool, _safe_error_message, _clamp_max_results
1712

1813

1914
# -- Dispatch --
@@ -68,11 +63,35 @@ async def test_dna_calls_correct_endpoint(self, mock_get):
6863
assert "/repos/r1/dna" in call_path
6964

7065
@pytest.mark.asyncio
71-
async def test_none_arguments_handled(self):
66+
@patch("handlers.api_get", new_callable=AsyncMock)
67+
async def test_none_arguments_handled(self, mock_get):
7268
"""call_tool(name, None) should not crash."""
69+
mock_get.return_value = {"repositories": []}
7370
result = await call_tool("list_repositories", None)
74-
# Will fail on network, but should not crash on None args
7571
assert len(result) == 1
72+
assert "No repositories indexed" in result[0].text
73+
74+
75+
# -- Input validation --
76+
77+
class TestClampMaxResults:
78+
def test_default_on_none(self):
79+
assert _clamp_max_results(None) == 10
80+
81+
def test_default_on_string(self):
82+
assert _clamp_max_results("abc") == 10
83+
84+
def test_clamps_zero_to_one(self):
85+
assert _clamp_max_results(0) == 1
86+
87+
def test_clamps_negative(self):
88+
assert _clamp_max_results(-5) == 1
89+
90+
def test_clamps_over_max(self):
91+
assert _clamp_max_results(500) == 100
92+
93+
def test_valid_value_passes(self):
94+
assert _clamp_max_results(25) == 25
7695

7796

7897
# -- Error handling --
@@ -99,11 +118,14 @@ def test_connect_error(self):
99118
msg = _safe_error_message("search_code", {}, error)
100119
assert "Cannot connect" in msg
101120

102-
def test_value_error_passthrough(self):
103-
"""ValueError messages are user-facing (e.g. missing API key)."""
121+
def test_value_error_sanitized(self):
122+
"""ValueError should not leak internal details."""
104123
error = ValueError("No API_KEY configured")
105-
msg = _safe_error_message("search_code", {}, error)
106-
assert "No API_KEY configured" in msg
124+
msg = _safe_error_message("search_code", {"repo_id": "r1"}, error)
125+
assert "Tool input error" in msg
126+
assert "search_code" in msg
127+
# Internal message should NOT be in output
128+
assert "No API_KEY" not in msg
107129

108130
def test_generic_error_hides_details(self):
109131
error = RuntimeError("internal traceback info")

mcp-server/tests/test_tools.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,6 @@
44
and that adding/removing tools doesn't break the schema contract.
55
"""
66
import pytest
7-
import sys
8-
import os
9-
10-
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
117

128
from tools import get_tool_schemas
139

@@ -62,6 +58,14 @@ def test_repo_tools_require_repo_id(self):
6258
required = schemas[name].inputSchema.get("required", [])
6359
assert "repo_id" in required, f"{name} should require repo_id"
6460

61+
def test_search_max_results_bounded(self):
62+
"""max_results schema should have min/max to prevent invalid searches."""
63+
schemas = {t.name: t for t in get_tool_schemas()}
64+
max_results = schemas["search_code"].inputSchema["properties"]["max_results"]
65+
assert max_results["type"] == "integer"
66+
assert max_results["minimum"] >= 1
67+
assert max_results["maximum"] > max_results["minimum"]
68+
6569
def test_list_repos_has_no_required_fields(self):
6670
schemas = {t.name: t for t in get_tool_schemas()}
6771
list_repos = schemas["list_repositories"]

mcp-server/tools.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,8 @@ def get_tool_schemas() -> list[types.Tool]:
3636
"type": "integer",
3737
"description": "Maximum number of results (default: 10)",
3838
"default": 10,
39+
"minimum": 1,
40+
"maximum": 100,
3941
},
4042
},
4143
"required": ["query", "repo_id"],

0 commit comments

Comments
 (0)