diff --git a/backend/services/indexer_optimized.py b/backend/services/indexer_optimized.py index 2509e41..2ecd8b7 100644 --- a/backend/services/indexer_optimized.py +++ b/backend/services/indexer_optimized.py @@ -361,19 +361,21 @@ def extract_functions_v2(self, repo_path: str, max_functions: int = 5000) -> Lis logger.info("V2 extraction", total=len(raw), kept=len(filtered)) return filtered - def _build_embedding_text(self, func: ExtractedFunction) -> str: + def _build_embedding_text(self, func: ExtractedFunction, summary: str = "") -> str: """Build rich text for embedding.""" parts = [ f"Function: {func.qualified_name}", f"Signature: {func.signature}", ] + if summary: + parts.append(f"Summary: {summary}") if func.docstring: parts.append(f"Description: {func.docstring[:500]}") parts.append(f"Language: {func.language}") parts.append(f"Code:\n{func.code[:2000]}") return "\n".join(parts) - def _build_metadata(self, func: ExtractedFunction, repo_id: str) -> Dict: + def _build_metadata(self, func: ExtractedFunction, repo_id: str, summary: str = "") -> Dict: """Build Pinecone metadata from function.""" return { "repo_id": repo_id, @@ -389,12 +391,21 @@ def _build_metadata(self, func: ExtractedFunction, repo_id: str) -> Dict: "class_name": func.class_name or "", "docstring": (func.docstring or "")[:500], "is_async": func.is_async, + "summary": summary, } - async def index_repository_v2(self, repo_id: str, repo_path: str, progress_callback=None) -> int: + async def index_repository_v2( + self, + repo_id: str, + repo_path: str, + progress_callback=None, + generate_summaries: bool = False + ) -> int: """Index repository using V2 function-level extraction.""" + from services.search_v2 import generate_summaries as gen_summaries + start_time = time.time() - logger.info("V2 indexing started", repo_id=repo_id) + logger.info("V2 indexing started", repo_id=repo_id, with_summaries=generate_summaries) functions = self.extract_functions_v2(repo_path) if not functions: @@ -402,8 +413,14 @@ async def index_repository_v2(self, repo_id: str, repo_path: str, progress_callb await progress_callback(0, 0, 0) return 0 + # generate summaries if requested + summaries = [""] * len(functions) + if generate_summaries: + logger.info("Generating summaries", count=len(functions)) + summaries = await gen_summaries(functions, batch_size=10) + # generate embeddings - texts = [self._build_embedding_text(f) for f in functions] + texts = [self._build_embedding_text(f, s) for f, s in zip(functions, summaries)] embeddings = [] for i in range(0, len(texts), self.EMBEDDING_BATCH_SIZE): @@ -418,9 +435,9 @@ async def index_repository_v2(self, repo_id: str, repo_path: str, progress_callb { "id": hashlib.md5(func.id_string.encode()).hexdigest(), "values": emb, - "metadata": self._build_metadata(func, repo_id) + "metadata": self._build_metadata(func, repo_id, summary) } - for func, emb in zip(functions, embeddings) + for func, emb, summary in zip(functions, embeddings, summaries) ] # upsert to pinecone diff --git a/backend/services/search_v2/__init__.py b/backend/services/search_v2/__init__.py index d5e4de0..948b235 100644 --- a/backend/services/search_v2/__init__.py +++ b/backend/services/search_v2/__init__.py @@ -2,6 +2,7 @@ from services.search_v2.types import ExtractedFunction, SearchResult, Language from services.search_v2.tree_sitter_extractor import TreeSitterExtractor from services.search_v2.function_filter import FunctionFilter, filter_functions +from services.search_v2.summary_generator import SummaryGenerator, generate_summaries __all__ = [ "ExtractedFunction", @@ -10,4 +11,6 @@ "TreeSitterExtractor", "FunctionFilter", "filter_functions", + "SummaryGenerator", + "generate_summaries", ] diff --git a/backend/services/search_v2/summary_generator.py b/backend/services/search_v2/summary_generator.py new file mode 100644 index 0000000..547aed2 --- /dev/null +++ b/backend/services/search_v2/summary_generator.py @@ -0,0 +1,132 @@ +"""Generate concise summaries for functions using GPT-4o-mini.""" +import asyncio +from typing import List, Optional +from openai import AsyncOpenAI + +from services.search_v2.types import ExtractedFunction +from services.observability import logger + +SUMMARY_PROMPT = """Summarize what this function does in one sentence (max 15 words). +Focus on the purpose, not implementation details. Be specific. + +Function: {name} +Signature: {signature} +Code: +```{language} +{code} +``` + +Summary:""" + +BATCH_PROMPT = """For each function below, write a one-sentence summary (max 15 words each). +Focus on purpose, not implementation. Be specific. Return one summary per line. + +{functions} + +Summaries (one per line):""" + + +class SummaryGenerator: + """Generate function summaries using GPT-4o-mini.""" + + def __init__(self, model: str = "gpt-4o-mini", batch_size: int = 10): + self.client = AsyncOpenAI() + self.model = model + self.batch_size = batch_size + + async def generate_single(self, func: ExtractedFunction) -> str: + """Generate summary for a single function.""" + prompt = SUMMARY_PROMPT.format( + name=func.qualified_name, + signature=func.signature, + language=func.language, + code=func.code[:1500], + ) + + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + max_tokens=50, + temperature=0.3, + ) + return response.choices[0].message.content.strip() + except Exception as e: + logger.warning("Summary generation failed", func=func.name, error=str(e)) + return "" + + async def generate_batch(self, functions: List[ExtractedFunction]) -> List[str]: + """Generate summaries for multiple functions in one API call.""" + if not functions: + return [] + + func_texts = [] + for i, func in enumerate(functions, 1): + func_texts.append( + f"{i}. {func.qualified_name}\n" + f" Signature: {func.signature}\n" + f" Code: {func.code[:800]}\n" + ) + + prompt = BATCH_PROMPT.format(functions="\n".join(func_texts)) + + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + max_tokens=50 * len(functions), + temperature=0.3, + ) + + lines = response.choices[0].message.content.strip().split("\n") + summaries = [] + + for line in lines: + line = line.strip() + if not line: + continue + # strip leading number if present (e.g., "1. Summary here") + if line[0].isdigit() and ". " in line[:4]: + line = line.split(". ", 1)[1] + summaries.append(line) + + # pad with empty strings if we got fewer summaries + while len(summaries) < len(functions): + summaries.append("") + + return summaries[:len(functions)] + + except Exception as e: + logger.warning("Batch summary failed", count=len(functions), error=str(e)) + return [""] * len(functions) + + async def generate_all( + self, + functions: List[ExtractedFunction], + progress_callback=None + ) -> List[str]: + """Generate summaries for all functions with batching.""" + all_summaries = [] + total = len(functions) + + for i in range(0, total, self.batch_size): + batch = functions[i:i + self.batch_size] + summaries = await self.generate_batch(batch) + all_summaries.extend(summaries) + + if progress_callback: + await progress_callback(len(all_summaries), total) + + logger.debug("Summaries generated", progress=len(all_summaries), total=total) + + return all_summaries + + +async def generate_summaries( + functions: List[ExtractedFunction], + batch_size: int = 10, + progress_callback=None +) -> List[str]: + """Convenience function to generate summaries.""" + generator = SummaryGenerator(batch_size=batch_size) + return await generator.generate_all(functions, progress_callback) diff --git a/backend/tests/test_summary_generator.py b/backend/tests/test_summary_generator.py new file mode 100644 index 0000000..a80588d --- /dev/null +++ b/backend/tests/test_summary_generator.py @@ -0,0 +1,142 @@ +"""Tests for AI summary generation.""" +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from services.search_v2 import SummaryGenerator, ExtractedFunction, generate_summaries + + +def make_func(name: str, code: str = "def x(): pass") -> ExtractedFunction: + return ExtractedFunction( + name=name, + qualified_name=name, + file_path="test.py", + code=code, + signature=f"def {name}():", + language="python", + start_line=1, + end_line=2, + ) + + +class TestSummaryGenerator: + + @pytest.fixture + def generator(self): + return SummaryGenerator(batch_size=3) + + @pytest.mark.asyncio + async def test_generate_single(self, generator): + func = make_func("process_data", "def process_data(items): return [x*2 for x in items]") + + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content="Doubles each item in the input list."))] + + with patch.object(generator.client.chat.completions, 'create', new_callable=AsyncMock) as mock: + mock.return_value = mock_response + summary = await generator.generate_single(func) + + assert summary == "Doubles each item in the input list." + mock.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_batch(self, generator): + funcs = [ + make_func("fetch_users"), + make_func("save_data"), + make_func("validate_input"), + ] + + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock( + content="1. Retrieves user records from database.\n2. Persists data to storage.\n3. Validates input parameters." + ))] + + with patch.object(generator.client.chat.completions, 'create', new_callable=AsyncMock) as mock: + mock.return_value = mock_response + summaries = await generator.generate_batch(funcs) + + assert len(summaries) == 3 + assert "user" in summaries[0].lower() + assert "data" in summaries[1].lower() or "storage" in summaries[1].lower() + assert "valid" in summaries[2].lower() + + @pytest.mark.asyncio + async def test_generate_batch_handles_fewer_responses(self, generator): + funcs = [make_func("a"), make_func("b"), make_func("c")] + + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content="Only one summary"))] + + with patch.object(generator.client.chat.completions, 'create', new_callable=AsyncMock) as mock: + mock.return_value = mock_response + summaries = await generator.generate_batch(funcs) + + assert len(summaries) == 3 + assert summaries[0] == "Only one summary" + assert summaries[1] == "" + assert summaries[2] == "" + + @pytest.mark.asyncio + async def test_generate_all_batches_correctly(self, generator): + funcs = [make_func(f"func_{i}") for i in range(7)] + + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock( + content="1. Summary one\n2. Summary two\n3. Summary three" + ))] + + with patch.object(generator.client.chat.completions, 'create', new_callable=AsyncMock) as mock: + mock.return_value = mock_response + summaries = await generator.generate_all(funcs) + + # batch_size=3, so 7 funcs = 3 API calls + assert mock.call_count == 3 + assert len(summaries) == 7 + + @pytest.mark.asyncio + async def test_generate_single_handles_error(self, generator): + func = make_func("broken") + + with patch.object(generator.client.chat.completions, 'create', new_callable=AsyncMock) as mock: + mock.side_effect = Exception("API error") + summary = await generator.generate_single(func) + + assert summary == "" + + @pytest.mark.asyncio + async def test_generate_batch_handles_error(self, generator): + funcs = [make_func("a"), make_func("b")] + + with patch.object(generator.client.chat.completions, 'create', new_callable=AsyncMock) as mock: + mock.side_effect = Exception("API error") + summaries = await generator.generate_batch(funcs) + + assert summaries == ["", ""] + + @pytest.mark.asyncio + async def test_empty_input(self, generator): + summaries = await generator.generate_batch([]) + assert summaries == [] + + summaries = await generator.generate_all([]) + assert summaries == [] + + +class TestConvenienceFunction: + + @pytest.mark.asyncio + async def test_generate_summaries_convenience(self): + funcs = [make_func("test_func")] + + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content="Test summary"))] + + with patch('services.search_v2.summary_generator.AsyncOpenAI') as MockClient: + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_client + + summaries = await generate_summaries(funcs, batch_size=5) + + assert len(summaries) == 1 + assert summaries[0] == "Test summary"