Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions backend/services/indexer_optimized.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,19 +361,21 @@ def extract_functions_v2(self, repo_path: str, max_functions: int = 5000) -> Lis
logger.info("V2 extraction", total=len(raw), kept=len(filtered))
return filtered

def _build_embedding_text(self, func: ExtractedFunction, summary: str = "") -> str:
    """Build the newline-joined text that is embedded for one function.

    Sections are emitted in a fixed order: qualified name, signature,
    optional summary, optional docstring, language, then code.

    Args:
        func: Extracted function; its ``qualified_name``, ``signature``,
            ``docstring``, ``language`` and ``code`` attributes are read.
        summary: Optional summary text; the ``Summary:`` section is
            omitted when this is empty.

    Returns:
        The sections joined with newlines.
    """
    parts = [
        f"Function: {func.qualified_name}",
        f"Signature: {func.signature}",
    ]
    if summary:
        parts.append(f"Summary: {summary}")
    if func.docstring:
        # Only the first 500 characters of the docstring are embedded.
        parts.append(f"Description: {func.docstring[:500]}")
    parts.append(f"Language: {func.language}")
    # Only the first 2000 characters of the code are embedded.
    parts.append(f"Code:\n{func.code[:2000]}")
    return "\n".join(parts)

def _build_metadata(self, func: ExtractedFunction, repo_id: str) -> Dict:
def _build_metadata(self, func: ExtractedFunction, repo_id: str, summary: str = "") -> Dict:
"""Build Pinecone metadata from function."""
return {
"repo_id": repo_id,
Expand All @@ -389,21 +391,36 @@ def _build_metadata(self, func: ExtractedFunction, repo_id: str) -> Dict:
"class_name": func.class_name or "",
"docstring": (func.docstring or "")[:500],
"is_async": func.is_async,
"summary": summary,
}

async def index_repository_v2(self, repo_id: str, repo_path: str, progress_callback=None) -> int:
async def index_repository_v2(
self,
repo_id: str,
repo_path: str,
progress_callback=None,
generate_summaries: bool = False
) -> int:
"""Index repository using V2 function-level extraction."""
from services.search_v2 import generate_summaries as gen_summaries

start_time = time.time()
logger.info("V2 indexing started", repo_id=repo_id)
logger.info("V2 indexing started", repo_id=repo_id, with_summaries=generate_summaries)

functions = self.extract_functions_v2(repo_path)
if not functions:
if progress_callback:
await progress_callback(0, 0, 0)
return 0

# generate summaries if requested
summaries = [""] * len(functions)
if generate_summaries:
logger.info("Generating summaries", count=len(functions))
summaries = await gen_summaries(functions, batch_size=10)

# generate embeddings
texts = [self._build_embedding_text(f) for f in functions]
texts = [self._build_embedding_text(f, s) for f, s in zip(functions, summaries)]
embeddings = []

for i in range(0, len(texts), self.EMBEDDING_BATCH_SIZE):
Expand All @@ -418,9 +435,9 @@ async def index_repository_v2(self, repo_id: str, repo_path: str, progress_callb
{
"id": hashlib.md5(func.id_string.encode()).hexdigest(),
"values": emb,
"metadata": self._build_metadata(func, repo_id)
"metadata": self._build_metadata(func, repo_id, summary)
}
for func, emb in zip(functions, embeddings)
for func, emb, summary in zip(functions, embeddings, summaries)
]

# upsert to pinecone
Expand Down
3 changes: 3 additions & 0 deletions backend/services/search_v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from services.search_v2.types import ExtractedFunction, SearchResult, Language
from services.search_v2.tree_sitter_extractor import TreeSitterExtractor
from services.search_v2.function_filter import FunctionFilter, filter_functions
from services.search_v2.summary_generator import SummaryGenerator, generate_summaries

__all__ = [
"ExtractedFunction",
Expand All @@ -10,4 +11,6 @@
"TreeSitterExtractor",
"FunctionFilter",
"filter_functions",
"SummaryGenerator",
"generate_summaries",
]
132 changes: 132 additions & 0 deletions backend/services/search_v2/summary_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Generate concise summaries for functions using GPT-4o-mini."""
import asyncio
from typing import List, Optional
from openai import AsyncOpenAI

from services.search_v2.types import ExtractedFunction
from services.observability import logger

# Prompt template for summarising a single function. Filled with
# str.format using the fields: name, signature, language, code.
SUMMARY_PROMPT = """Summarize what this function does in one sentence (max 15 words).
Focus on the purpose, not implementation details. Be specific.

Function: {name}
Signature: {signature}
Code:
```{language}
{code}
```

Summary:"""

# Prompt template for summarising several functions in one request.
# Filled with str.format using the single field: functions (the
# pre-rendered, numbered list of functions). Asks for one summary per line.
BATCH_PROMPT = """For each function below, write a one-sentence summary (max 15 words each).
Focus on purpose, not implementation. Be specific. Return one summary per line.

{functions}

Summaries (one per line):"""


class SummaryGenerator:
    """Generate one-sentence function summaries with an OpenAI chat model.

    Functions can be summarised one at a time (``generate_single``) or
    several per request (``generate_batch`` / ``generate_all``). API
    failures are logged as warnings and reported as empty strings, so the
    batch methods always return exactly one entry per input function.
    """

    def __init__(self, model: str = "gpt-4o-mini", batch_size: int = 10):
        """Create a generator.

        Args:
            model: Chat-completions model name.
            batch_size: Number of functions sent per request by
                ``generate_all``.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        # A zero step makes range() raise later and a negative step makes
        # generate_all() silently return no summaries at all, so reject both
        # before doing any work.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.client = AsyncOpenAI()
        self.model = model
        self.batch_size = batch_size

    async def generate_single(self, func: ExtractedFunction) -> str:
        """Generate a summary for a single function.

        Args:
            func: Function to summarise; only the first 1500 characters of
                its code are sent.

        Returns:
            The stripped summary text, or ``""`` if the request fails or the
            response carries no text.
        """
        prompt = SUMMARY_PROMPT.format(
            name=func.qualified_name,
            signature=func.signature,
            language=func.language,
            code=func.code[:1500],
        )

        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=50,
                temperature=0.3,
            )
            # The message content may be null in the API response.
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            logger.warning("Summary generation failed", func=func.name, error=str(e))
            return ""

    async def generate_batch(self, functions: List[ExtractedFunction]) -> List[str]:
        """Generate summaries for multiple functions in one API call.

        Args:
            functions: Functions to summarise; only the first 800 characters
                of each function's code are sent.

        Returns:
            A list with exactly ``len(functions)`` entries, aligned with the
            input order. Entries are ``""`` where no summary was obtained,
            and all entries are ``""`` if the request fails.
        """
        if not functions:
            return []

        # Number the functions from 1 so the reply can be matched back to
        # its function by number rather than by line position.
        func_texts = [
            f"{i}. {func.qualified_name}\n"
            f" Signature: {func.signature}\n"
            f" Code: {func.code[:800]}\n"
            for i, func in enumerate(functions, 1)
        ]
        prompt = BATCH_PROMPT.format(functions="\n".join(func_texts))

        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=50 * len(functions),
                temperature=0.3,
            )
            return self._parse_batch_response(
                response.choices[0].message.content, len(functions)
            )
        except Exception as e:
            logger.warning("Batch summary failed", count=len(functions), error=str(e))
            return [""] * len(functions)

    @classmethod
    def _parse_batch_response(cls, content: Optional[str], expected: int) -> List[str]:
        """Turn the model's reply into exactly ``expected`` summaries.

        When the reply's lines carry a consistent 1-based numbering (unique
        and within ``1..expected``), each summary is placed at the slot its
        number names. A skipped item or an unnumbered preamble line then
        cannot shift later summaries onto the wrong function. Otherwise the
        non-empty lines are taken in order, with any leading number removed.

        Args:
            content: Raw reply text; ``None`` is treated as empty.
            expected: Number of summaries to return.

        Returns:
            A list of length ``expected``, padded with ``""`` or truncated.
        """
        entries = [
            cls._split_numbered(line)
            for line in (raw.strip() for raw in (content or "").split("\n"))
            if line
        ]

        numbers = [number for number, _ in entries if number is not None]
        numbering_is_consistent = (
            bool(numbers)
            and len(set(numbers)) == len(numbers)
            and all(1 <= number <= expected for number in numbers)
        )

        if numbering_is_consistent:
            slots = [""] * expected
            for number, text in entries:
                if number is not None:
                    slots[number - 1] = text
            return slots

        # Positional fallback: pad with empty strings, then truncate.
        texts = [text for _, text in entries]
        return (texts + [""] * expected)[:expected]

    @staticmethod
    def _split_numbered(line: str):
        """Split ``"N. text"`` or ``"N) text"`` into ``(N, text)``.

        Returns ``(None, line)`` when the line has no such numeric prefix.
        """
        remainder = line.lstrip("0123456789")
        digits = line[: len(line) - len(remainder)]
        if digits and remainder[:2] in (". ", ") "):
            return int(digits), remainder[2:].strip()
        return None, line

    async def generate_all(
        self,
        functions: List[ExtractedFunction],
        progress_callback=None
    ) -> List[str]:
        """Generate summaries for all functions with batching.

        Batches of ``self.batch_size`` functions are processed one after
        another.

        Args:
            functions: Functions to summarise.
            progress_callback: Optional callable that is awaited after each
                batch as ``progress_callback(done, total)``.

        Returns:
            One summary per input function, in input order.
        """
        all_summaries = []
        total = len(functions)

        for i in range(0, total, self.batch_size):
            batch = functions[i:i + self.batch_size]
            summaries = await self.generate_batch(batch)
            all_summaries.extend(summaries)

            if progress_callback:
                await progress_callback(len(all_summaries), total)

            logger.debug("Summaries generated", progress=len(all_summaries), total=total)

        return all_summaries


async def generate_summaries(
    functions: List[ExtractedFunction],
    batch_size: int = 10,
    progress_callback=None
) -> List[str]:
    """Summarise ``functions`` with a freshly created SummaryGenerator.

    Builds a generator with the given ``batch_size`` and returns the result
    of its ``generate_all``, passing ``progress_callback`` through.
    """
    return await SummaryGenerator(batch_size=batch_size).generate_all(
        functions, progress_callback
    )
142 changes: 142 additions & 0 deletions backend/tests/test_summary_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Tests for AI summary generation."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock

from services.search_v2 import SummaryGenerator, ExtractedFunction, generate_summaries


def make_func(name: str, code: str = "def x(): pass") -> ExtractedFunction:
    """Build an ExtractedFunction test fixture.

    ``name`` is used for both the name and the qualified name; the remaining
    fields are fixed placeholder values for a Python function in ``test.py``.
    """
    fields = {
        "name": name,
        "qualified_name": name,
        "file_path": "test.py",
        "code": code,
        "signature": f"def {name}():",
        "language": "python",
        "start_line": 1,
        "end_line": 2,
    }
    return ExtractedFunction(**fields)


class TestSummaryGenerator:
    """Unit tests for SummaryGenerator with the OpenAI client mocked out."""

    @staticmethod
    def _response(content):
        """Build a fake chat-completion response whose first choice holds ``content``."""
        return MagicMock(choices=[MagicMock(message=MagicMock(content=content))])

    @pytest.fixture
    def generator(self):
        # Constructing a real AsyncOpenAI client fails when no API key is
        # configured, so replace the class while the generator is built.
        # generator.client is then a MagicMock that each test patches further.
        with patch("services.search_v2.summary_generator.AsyncOpenAI"):
            return SummaryGenerator(batch_size=3)

    @pytest.mark.asyncio
    async def test_generate_single(self, generator):
        func = make_func("process_data", "def process_data(items): return [x*2 for x in items]")

        with patch.object(generator.client.chat.completions, "create", new_callable=AsyncMock) as mock:
            mock.return_value = self._response("Doubles each item in the input list.")
            summary = await generator.generate_single(func)

        assert summary == "Doubles each item in the input list."
        mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_batch(self, generator):
        funcs = [
            make_func("fetch_users"),
            make_func("save_data"),
            make_func("validate_input"),
        ]
        content = (
            "1. Retrieves user records from database.\n"
            "2. Persists data to storage.\n"
            "3. Validates input parameters."
        )

        with patch.object(generator.client.chat.completions, "create", new_callable=AsyncMock) as mock:
            mock.return_value = self._response(content)
            summaries = await generator.generate_batch(funcs)

        assert len(summaries) == 3
        assert "user" in summaries[0].lower()
        assert "data" in summaries[1].lower() or "storage" in summaries[1].lower()
        assert "valid" in summaries[2].lower()

    @pytest.mark.asyncio
    async def test_generate_batch_handles_fewer_responses(self, generator):
        funcs = [make_func("a"), make_func("b"), make_func("c")]

        with patch.object(generator.client.chat.completions, "create", new_callable=AsyncMock) as mock:
            mock.return_value = self._response("Only one summary")
            summaries = await generator.generate_batch(funcs)

        # Missing summaries are padded with empty strings.
        assert summaries == ["Only one summary", "", ""]

    @pytest.mark.asyncio
    async def test_generate_all_batches_correctly(self, generator):
        funcs = [make_func(f"func_{i}") for i in range(7)]

        with patch.object(generator.client.chat.completions, "create", new_callable=AsyncMock) as mock:
            mock.return_value = self._response("1. Summary one\n2. Summary two\n3. Summary three")
            summaries = await generator.generate_all(funcs)

        # batch_size=3, so 7 funcs = 3 API calls
        assert mock.call_count == 3
        assert len(summaries) == 7

    @pytest.mark.asyncio
    async def test_generate_single_handles_error(self, generator):
        func = make_func("broken")

        with patch.object(generator.client.chat.completions, "create", new_callable=AsyncMock) as mock:
            mock.side_effect = Exception("API error")
            summary = await generator.generate_single(func)

        assert summary == ""

    @pytest.mark.asyncio
    async def test_generate_batch_handles_error(self, generator):
        funcs = [make_func("a"), make_func("b")]

        with patch.object(generator.client.chat.completions, "create", new_callable=AsyncMock) as mock:
            mock.side_effect = Exception("API error")
            summaries = await generator.generate_batch(funcs)

        assert summaries == ["", ""]

    @pytest.mark.asyncio
    async def test_empty_input(self, generator):
        assert await generator.generate_batch([]) == []
        assert await generator.generate_all([]) == []


class TestConvenienceFunction:
    """Tests for the module-level generate_summaries helper."""

    @pytest.mark.asyncio
    async def test_generate_summaries_convenience(self):
        # Fake response whose first choice carries the summary text.
        fake_response = MagicMock()
        fake_response.choices = [MagicMock(message=MagicMock(content="Test summary"))]

        # Fake client whose create() coroutine returns that response.
        fake_client = MagicMock()
        fake_client.chat.completions.create = AsyncMock(return_value=fake_response)

        with patch('services.search_v2.summary_generator.AsyncOpenAI', return_value=fake_client):
            result = await generate_summaries([make_func("test_func")], batch_size=5)

        assert len(result) == 1
        assert result[0] == "Test summary"