Skip to content

Commit 25569af

Browse files
authored
Merge pull request #138 from DevanshuNEU/feat/function-level-chunking
feat(search-v2): implement function-level chunking with tree-sitter
2 parents c14e1ca + 1900061 commit 25569af

6 files changed

Lines changed: 855 additions & 1 deletion

File tree

backend/services/indexer_optimized.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
# Search enhancement
3232
from services.search_enhancer import SearchEnhancer
3333

34+
# Search V2 - Function-level extraction (Issue #68)
35+
from services.search_v2 import TreeSitterExtractor, FunctionFilter, ExtractedFunction
36+
3437
# Observability
3538
from services.observability import logger, trace_operation, track_time, capture_exception, add_breadcrumb, metrics
3639

@@ -89,6 +92,14 @@ def __init__(self):
8992
'typescript': self._create_parser(Language(tsjavascript.language())),
9093
}
9194

95+
# Search V2: Initialize advanced tree-sitter extractor and filter (Issue #68)
96+
self.tree_sitter_extractor = TreeSitterExtractor()
97+
self.function_filter = FunctionFilter(
98+
include_private=False,
99+
include_dunders=True,
100+
max_name_length=50
101+
)
102+
92103
logger.info("OptimizedCodeIndexer initialized", model=EMBEDDING_MODEL)
93104

94105
def _create_parser(self, language) -> Parser:
@@ -339,7 +350,89 @@ async def _extract_functions_from_file(
339350
except Exception as e:
340351
logger.error("Error processing file", file_path=file_path, error=str(e))
341352
return []
342-
353+
354+
def extract_functions_v2(self, repo_path: str, max_functions: int = 5000) -> "List[ExtractedFunction]":
    """Extract functions from a repository and drop low-quality ones.

    Args:
        repo_path: Filesystem path of the repository; converted to a
            ``Path`` before being handed to the tree-sitter extractor.
        max_functions: Forwarded unchanged to
            ``tree_sitter_extractor.extract_from_repo`` as its
            ``max_functions`` argument.

    Returns:
        The functions that ``self.function_filter.filter_functions`` kept.
    """
    from pathlib import Path

    raw = self.tree_sitter_extractor.extract_from_repo(
        Path(repo_path), max_functions=max_functions
    )
    filtered = self.function_filter.filter_functions(raw)

    # Log both counts so the filter's effect is visible per run.
    logger.info("V2 extraction", total=len(raw), kept=len(filtered))
    return filtered
363+
364+
def _build_embedding_text(self, func: "ExtractedFunction") -> str:
    """Build the newline-joined text that is embedded for one function.

    The text contains, in order: the qualified name, the signature, the
    first 500 characters of the docstring (only when the docstring is
    truthy), the language, and the first 2000 characters of the code.

    Args:
        func: Object exposing ``qualified_name``, ``signature``,
            ``docstring``, ``language`` and ``code``.

    Returns:
        A single string with one labelled section per line.
    """
    parts = [
        f"Function: {func.qualified_name}",
        f"Signature: {func.signature}",
    ]
    # The description line is omitted entirely for missing/empty docstrings.
    if func.docstring:
        parts.append(f"Description: {func.docstring[:500]}")
    parts.append(f"Language: {func.language}")
    parts.append(f"Code:\n{func.code[:2000]}")
    return "\n".join(parts)
375+
376+
def _build_metadata(self, func: "ExtractedFunction", repo_id: str) -> Dict:
    """Build the metadata dictionary stored alongside a function's vector.

    Args:
        func: Object exposing the attributes read below (``file_path``,
            ``name``, ``qualified_name``, ``is_method``, ``code``,
            ``signature``, ``start_line``, ``end_line``, ``language``,
            ``class_name``, ``docstring``, ``is_async``).
        repo_id: Identifier of the repository, stored under ``"repo_id"``.

    Returns:
        A dict of metadata fields. ``code`` is cut to 1000 characters and
        ``docstring`` to 500; a falsy ``class_name`` or ``docstring`` is
        stored as an empty string.
    """
    return {
        "repo_id": repo_id,
        "file_path": func.file_path,
        "name": func.name,
        "qualified_name": func.qualified_name,
        "type": "method" if func.is_method else "function",
        "code": func.code[:1000],
        "signature": func.signature,
        "start_line": func.start_line,
        "end_line": func.end_line,
        "language": func.language,
        # Falsy values are replaced by "" so the field is always a string.
        "class_name": func.class_name or "",
        "docstring": (func.docstring or "")[:500],
        "is_async": func.is_async,
    }
393+
394+
async def index_repository_v2(self, repo_id: str, repo_path: str, progress_callback=None) -> int:
    """Index a repository using V2 function-level extraction.

    Steps: extract and filter functions, embed their text in batches of
    ``self.EMBEDDING_BATCH_SIZE``, build one vector per function, and
    upsert the vectors in batches of ``self.PINECONE_UPSERT_BATCH``.

    Args:
        repo_id: Repository identifier; stored in each vector's metadata.
        repo_path: Filesystem path of the repository to index.
        progress_callback: Optional awaitable callable invoked with three
            positional ints. It is called with ``(0, 0, 0)`` when nothing
            was extracted, otherwise with
            ``(embedded_so_far, total, total)`` after each embedding batch.

    Returns:
        The number of vectors built and upserted (0 when no functions
        were extracted).
    """
    start_time = time.time()
    logger.info("V2 indexing started", repo_id=repo_id)

    functions = self.extract_functions_v2(repo_path)
    if not functions:
        if progress_callback:
            await progress_callback(0, 0, 0)
        return 0

    # Generate embeddings batch by batch.
    texts = [self._build_embedding_text(f) for f in functions]
    embeddings = []
    total = len(functions)

    for i in range(0, len(texts), self.EMBEDDING_BATCH_SIZE):
        batch = texts[i:i + self.EMBEDDING_BATCH_SIZE]
        embeddings.extend(await self._create_embeddings_batch(batch))

        if progress_callback:
            await progress_callback(len(embeddings), total, total)

    # zip() stops at the shorter input, so a short embedding list would
    # silently drop functions; surface that instead of hiding it.
    if len(embeddings) != total:
        logger.error(
            "V2 embedding count mismatch",
            repo_id=repo_id,
            functions=total,
            embeddings=len(embeddings),
        )

    # Build vectors.
    # NOTE(review): the vector ID is derived only from func.id_string; if that
    # string does not encode the repository, identical functions in two repos
    # would share an ID — confirm against ExtractedFunction.
    vectors = [
        {
            "id": hashlib.md5(func.id_string.encode()).hexdigest(),
            "values": emb,
            "metadata": self._build_metadata(func, repo_id),
        }
        for func, emb in zip(functions, embeddings)
    ]

    # Upsert to Pinecone in fixed-size batches.
    # NOTE(review): upsert is called without await inside a coroutine; if the
    # client call blocks, it stalls the event loop — confirm.
    for i in range(0, len(vectors), self.PINECONE_UPSERT_BATCH):
        self.index.upsert(vectors=vectors[i:i + self.PINECONE_UPSERT_BATCH])

    indexed = len(vectors)
    elapsed = time.time() - start_time
    logger.info("V2 indexing complete", repo_id=repo_id, functions=indexed, duration_s=round(elapsed, 2))
    metrics.increment("indexing_v2_completed")

    return indexed
435+
343436
async def semantic_search(
344437
self,
345438
query: str,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Search V2: Function-level semantic search."""
2+
from services.search_v2.types import ExtractedFunction, SearchResult, Language
3+
from services.search_v2.tree_sitter_extractor import TreeSitterExtractor
4+
from services.search_v2.function_filter import FunctionFilter, filter_functions
5+
6+
# Names re-exported as the public surface of this package.
__all__ = [
    "ExtractedFunction",
    "SearchResult",
    "Language",
    "TreeSitterExtractor",
    "FunctionFilter",
    "filter_functions",
]
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""Filter out low-quality functions from search index."""
2+
from typing import List, Set
3+
from services.search_v2.types import ExtractedFunction
4+
from services.observability import logger
5+
6+
# Junk function name prefixes (compared against the lower-cased name).
JUNK_PREFIXES = (
    'test_', 'time_', 'rand_', 'mock_', 'fake_', 'stub_',
    'setup_', 'teardown_', 'fixture_', 'check_',
    '_test_', '_time_', '_rand_', '_mock_',
    'assert_', 'verify_', 'validate_test',
)

# Junk patterns anywhere in the lower-cased name.
JUNK_PATTERNS = (
    '_fixture', '_setup', '_teardown', '_helper_test',
    'benchmark', '_bench', '_perf_',
    '_random_data', '_test_data', '_sample_data',
    'from_int_dict', 'from_test', '_for_test',
    '_for_split', 'create_data_for',
    'doesnt_use_', 'check_main',
)

# Junk file path fragments (substring match on the lower-cased path).
JUNK_PATHS = (
    'tests/', 'test/', 'testing/', '/tests/', '/test/', '/testing/',
    'benchmarks/', 'asv_bench/', 'bench/',
    'examples/', 'docs/', 'doc/',
    '_testing/', 'conftest',
    'fixtures/', '_fixtures/',
    'mock/', 'mocks/', 'stubs/',
)

# Keep these even if they match junk patterns (exact, lower-cased names).
PUBLIC_API: Set[str] = {
    'read_csv', 'read_excel', 'read_json', 'read_parquet', 'read_sql',
    'to_csv', 'to_excel', 'to_json', 'to_parquet', 'to_sql',
    'merge', 'concat', 'groupby', 'pivot', 'melt',
    'fillna', 'dropna', 'isna', 'notna',
    'apply', 'map', 'transform', 'agg', 'aggregate',
    'sort_values', 'sort_index', 'reset_index', 'set_index',
    'authenticate', 'authorize', 'login', 'logout',
    'validate', 'serialize', 'deserialize',
    'create', 'read', 'update', 'delete',
    'get', 'set', 'post', 'put', 'patch',
    'connect', 'disconnect', 'send', 'receive',
    'parse', 'format', 'convert',
    'load', 'save', 'export', 'import_',
    'init', 'setup', 'configure', 'initialize',
}


class FunctionFilter:
    """Filter functions to keep only high-quality, searchable ones."""

    def __init__(
        self,
        include_private: bool = False,
        include_dunders: bool = True,
        max_name_length: int = 50,
    ):
        """Store the filter switches.

        Args:
            include_private: Whether names starting with an underscore
                (that are not dunders) are kept.
            include_dunders: Whether ``__name__``-style names are kept.
            max_name_length: Names longer than this are dropped.
        """
        self.include_private = include_private
        self.include_dunders = include_dunders
        self.max_name_length = max_name_length

    def filter_functions(self, functions: "List[ExtractedFunction]") -> "List[ExtractedFunction]":
        """Return the functions that pass ``_keep``, preserving input order.

        A debug log entry with the kept/removed counts is emitted only when
        at least one function was removed.
        """
        kept = [f for f in functions if self._keep(f)]
        removed = len(functions) - len(kept)

        if removed > 0:
            logger.debug("Filtered functions", kept=len(kept), removed=removed)

        return kept

    def _keep(self, func: "ExtractedFunction") -> bool:
        """Decide whether one function stays in the index.

        Rules are applied in order; the first one that matches decides:
        exact public-API name -> keep; junk path, junk prefix, junk
        pattern, or over-long name -> drop; dunder -> ``include_dunders``;
        other leading underscore -> ``include_private``; ``make_*`` test
        data generators -> drop; otherwise keep.

        Args:
            func: Object exposing ``name`` and ``file_path``.
        """
        raw_name = func.name
        name = raw_name.lower()
        path = func.file_path.lower()

        # Always keep public API. This must be an exact match: substring
        # matching on short entries such as 'get', 'set' or 'map' would
        # whitelist nearly every name and make the junk rules unreachable.
        if name in PUBLIC_API:
            return True

        # Skip junk paths.
        if any(fragment in path for fragment in JUNK_PATHS):
            return False

        # Skip junk prefixes (str.startswith accepts a tuple).
        if name.startswith(JUNK_PREFIXES):
            return False

        # Skip junk patterns.
        if any(pattern in name for pattern in JUNK_PATTERNS):
            return False

        # Skip long auto-generated names.
        if len(name) > self.max_name_length:
            return False

        # Dunders are checked before the private rule because they also
        # start with an underscore.
        if raw_name.startswith('__') and raw_name.endswith('__'):
            return self.include_dunders

        # Any other leading underscore is private, including name-mangled
        # '__name' identifiers that have no trailing double underscore.
        if raw_name.startswith('_'):
            return self.include_private

        # Skip test data generators.
        if name.startswith('make_') and ('test' in path or 'random' in name):
            return False

        return True

    def get_stats(self, functions: "List[ExtractedFunction]") -> dict:
        """Return ``total``, ``kept`` and ``removed`` counts for *functions*."""
        kept = sum(1 for f in functions if self._keep(f))
        return {
            "total": len(functions),
            "kept": kept,
            "removed": len(functions) - kept,
        }
120+
121+
122+
# Shared module-level instance built with FunctionFilter's default arguments.
default_filter = FunctionFilter()


def filter_functions(functions: "List[ExtractedFunction]") -> "List[ExtractedFunction]":
    """Filter *functions* using the module-level ``default_filter``."""
    return default_filter.filter_functions(functions)

0 commit comments

Comments
 (0)