From 19000610dfdd1b0b01086f2f4e32e0b8b2d8f913 Mon Sep 17 00:00:00 2001 From: Devanshu Rajesh Chicholikar Date: Sun, 4 Jan 2026 16:21:19 -0500 Subject: [PATCH] feat(search-v2): implement function-level chunking with tree-sitter Closes #68 Add function-level code extraction using tree-sitter AST parsing: - TreeSitterExtractor for Python/JS/TS function extraction - FunctionFilter to remove test/junk functions - ExtractedFunction and SearchResult data models - index_repository_v2() method in indexer - 15 tests passing Part 1 of 4 in Search V2 epic. --- backend/services/indexer_optimized.py | 95 +++++- backend/services/search_v2/__init__.py | 13 + backend/services/search_v2/function_filter.py | 127 ++++++++ .../search_v2/tree_sitter_extractor.py | 285 ++++++++++++++++++ backend/services/search_v2/types.py | 88 ++++++ backend/tests/test_search_v2_chunking.py | 248 +++++++++++++++ 6 files changed, 855 insertions(+), 1 deletion(-) create mode 100644 backend/services/search_v2/__init__.py create mode 100644 backend/services/search_v2/function_filter.py create mode 100644 backend/services/search_v2/tree_sitter_extractor.py create mode 100644 backend/services/search_v2/types.py create mode 100644 backend/tests/test_search_v2_chunking.py diff --git a/backend/services/indexer_optimized.py b/backend/services/indexer_optimized.py index 4879065..2509e41 100644 --- a/backend/services/indexer_optimized.py +++ b/backend/services/indexer_optimized.py @@ -31,6 +31,9 @@ # Search enhancement from services.search_enhancer import SearchEnhancer +# Search V2 - Function-level extraction (Issue #68) +from services.search_v2 import TreeSitterExtractor, FunctionFilter, ExtractedFunction + # Observability from services.observability import logger, trace_operation, track_time, capture_exception, add_breadcrumb, metrics @@ -89,6 +92,14 @@ def __init__(self): 'typescript': self._create_parser(Language(tsjavascript.language())), } + # Search V2: Initialize advanced tree-sitter extractor and filter (Issue #68) + self.tree_sitter_extractor = TreeSitterExtractor() + self.function_filter = FunctionFilter( + include_private=False, + include_dunders=True, + max_name_length=50 + ) + logger.info("OptimizedCodeIndexer initialized", model=EMBEDDING_MODEL) def _create_parser(self, language) -> Parser: @@ -339,7 +350,89 @@ async def _extract_functions_from_file( except Exception as e: logger.error("Error processing file", file_path=file_path, error=str(e)) return [] - + + def extract_functions_v2(self, repo_path: str, max_functions: int = 5000) -> List[ExtractedFunction]: + """Extract and filter functions using tree-sitter.""" + from pathlib import Path + + raw = self.tree_sitter_extractor.extract_from_repo(Path(repo_path), max_functions=max_functions) + filtered = self.function_filter.filter_functions(raw) + + logger.info("V2 extraction", total=len(raw), kept=len(filtered)) + return filtered + + def _build_embedding_text(self, func: ExtractedFunction) -> str: + """Build rich text for embedding.""" + parts = [ + f"Function: {func.qualified_name}", + f"Signature: {func.signature}", + ] + if func.docstring: + parts.append(f"Description: {func.docstring[:500]}") + parts.append(f"Language: {func.language}") + parts.append(f"Code:\n{func.code[:2000]}") + return "\n".join(parts) + + def _build_metadata(self, func: ExtractedFunction, repo_id: str) -> Dict: + """Build Pinecone metadata from function.""" + return { + "repo_id": repo_id, + "file_path": func.file_path, + "name": func.name, + "qualified_name": func.qualified_name, + "type": "method" if func.is_method else "function", + "code": func.code[:1000], + "signature": func.signature, + "start_line": func.start_line, + "end_line": func.end_line, + "language": func.language, + "class_name": func.class_name or "", + "docstring": (func.docstring or "")[:500], + "is_async": func.is_async, + } + + async def index_repository_v2(self, repo_id: str, repo_path: str, progress_callback=None) -> int: + """Index repository using V2 function-level extraction.""" + start_time = time.time() + logger.info("V2 indexing started", repo_id=repo_id) + + functions = self.extract_functions_v2(repo_path) + if not functions: + if progress_callback: + await progress_callback(0, 0, 0) + return 0 + + # generate embeddings + texts = [self._build_embedding_text(f) for f in functions] + embeddings = [] + + for i in range(0, len(texts), self.EMBEDDING_BATCH_SIZE): + batch = texts[i:i + self.EMBEDDING_BATCH_SIZE] + embeddings.extend(await self._create_embeddings_batch(batch)) + + if progress_callback: + await progress_callback(len(embeddings), len(functions), len(functions)) + + # build vectors + vectors = [ + { + "id": hashlib.md5(func.id_string.encode()).hexdigest(), + "values": emb, + "metadata": self._build_metadata(func, repo_id) + } + for func, emb in zip(functions, embeddings) + ] + + # upsert to pinecone + for i in range(0, len(vectors), self.PINECONE_UPSERT_BATCH): + self.index.upsert(vectors=vectors[i:i + self.PINECONE_UPSERT_BATCH]) + + elapsed = time.time() - start_time + logger.info("V2 indexing complete", repo_id=repo_id, functions=len(functions), duration_s=round(elapsed, 2)) + metrics.increment("indexing_v2_completed") + + return len(functions) + async def semantic_search( self, query: str, diff --git a/backend/services/search_v2/__init__.py b/backend/services/search_v2/__init__.py new file mode 100644 index 0000000..d5e4de0 --- /dev/null +++ b/backend/services/search_v2/__init__.py @@ -0,0 +1,13 @@ +"""Search V2: Function-level semantic search.""" +from services.search_v2.types import ExtractedFunction, SearchResult, Language +from services.search_v2.tree_sitter_extractor import TreeSitterExtractor +from services.search_v2.function_filter import FunctionFilter, filter_functions + +__all__ = [ + "ExtractedFunction", + "SearchResult", + "Language", + "TreeSitterExtractor", + "FunctionFilter", + "filter_functions", +] diff --git a/backend/services/search_v2/function_filter.py b/backend/services/search_v2/function_filter.py new file mode 100644 index 0000000..ea7519e --- /dev/null +++ b/backend/services/search_v2/function_filter.py @@ -0,0 +1,127 @@ +"""Filter out low-quality functions from search index.""" +from typing import List, Set +from services.search_v2.types import ExtractedFunction +from services.observability import logger + +# Junk function name prefixes +JUNK_PREFIXES = ( + 'test_', 'time_', 'rand_', 'mock_', 'fake_', 'stub_', + 'setup_', 'teardown_', 'fixture_', 'check_', + '_test_', '_time_', '_rand_', '_mock_', + 'assert_', 'verify_', 'validate_test', +) + +# Junk patterns anywhere in name +JUNK_PATTERNS = ( + '_fixture', '_setup', '_teardown', '_helper_test', + 'benchmark', '_bench', '_perf_', + '_random_data', '_test_data', '_sample_data', + 'from_int_dict', 'from_test', '_for_test', + '_for_split', 'create_data_for', + 'doesnt_use_', 'check_main', +) + +# Junk file paths +JUNK_PATHS = ( + 'tests/', 'test/', 'testing/', '/tests/', '/test/', '/testing/', + 'benchmarks/', 'asv_bench/', 'bench/', + 'examples/', 'docs/', 'doc/', + '_testing/', 'conftest', + 'fixtures/', '_fixtures/', + 'mock/', 'mocks/', 'stubs/', +) + +# Keep these even if they match junk patterns +PUBLIC_API: Set[str] = { + 'read_csv', 'read_excel', 'read_json', 'read_parquet', 'read_sql', + 'to_csv', 'to_excel', 'to_json', 'to_parquet', 'to_sql', + 'merge', 'concat', 'groupby', 'pivot', 'melt', + 'fillna', 'dropna', 'isna', 'notna', + 'apply', 'map', 'transform', 'agg', 'aggregate', + 'sort_values', 'sort_index', 'reset_index', 'set_index', + 'authenticate', 'authorize', 'login', 'logout', + 'validate', 'serialize', 'deserialize', + 'create', 'read', 'update', 'delete', + 'get', 'set', 'post', 'put', 'patch', + 'connect', 'disconnect', 'send', 'receive', + 'parse', 'format', 'convert', 'transform', + 'load', 'save', 'export', 'import_', + 'init', 'setup', 'configure', 'initialize', +} + + +class FunctionFilter: + """Filter functions to keep only high-quality, searchable ones.""" + + def __init__( + self, + include_private: bool = False, + include_dunders: bool = True, + max_name_length: int = 50, + ): + self.include_private = include_private + self.include_dunders = include_dunders + self.max_name_length = max_name_length + + def filter_functions(self, functions: List[ExtractedFunction]) -> List[ExtractedFunction]: + original = len(functions) + filtered = [f for f in functions if self._keep(f)] + + if original - len(filtered) > 0: + logger.debug("Filtered functions", kept=len(filtered), removed=original - len(filtered)) + + return filtered + + def _keep(self, func: ExtractedFunction) -> bool: + name = func.name.lower() + path = func.file_path.lower() + + # always keep public API + if any(api in name for api in PUBLIC_API): + return True + + # skip junk paths + if any(p in path for p in JUNK_PATHS): + return False + + # skip junk prefixes + if name.startswith(JUNK_PREFIXES): + return False + + # skip junk patterns + if any(p in name for p in JUNK_PATTERNS): + return False + + # skip long auto-generated names + if len(name) > self.max_name_length: + return False + + # handle private functions + if func.name.startswith('_') and not func.name.startswith('__'): + return self.include_private + + # handle dunders + if func.name.startswith('__') and func.name.endswith('__'): + return self.include_dunders + + # skip test data generators + if name.startswith('make_') and ('test' in path or 'random' in name): + return False + + return True + + def get_stats(self, functions: List[ExtractedFunction]) -> dict: + quality = [f for f in functions if self._keep(f)] + return { + "total": len(functions), + "kept": len(quality), + "removed": len(functions) - len(quality), + } + + +default_filter = FunctionFilter() + + +def filter_functions(functions: List[ExtractedFunction]) -> List[ExtractedFunction]: + """Filter using default settings.""" + return default_filter.filter_functions(functions) diff --git a/backend/services/search_v2/tree_sitter_extractor.py b/backend/services/search_v2/tree_sitter_extractor.py new file mode 100644 index 0000000..1239465 --- /dev/null +++ b/backend/services/search_v2/tree_sitter_extractor.py @@ -0,0 +1,285 @@ +"""AST-based function extraction using tree-sitter.""" +import os +from pathlib import Path +from typing import List, Optional, Dict, Set + +import tree_sitter_python as tspython +import tree_sitter_javascript as tsjavascript +from tree_sitter import Language, Parser + +from services.search_v2.types import ExtractedFunction +from services.observability import logger + +PY_LANGUAGE = Language(tspython.language()) +JS_LANGUAGE = Language(tsjavascript.language()) + +LANGUAGE_MAP: Dict[str, tuple] = { + '.py': ('python', PY_LANGUAGE), + '.js': ('javascript', JS_LANGUAGE), + '.jsx': ('javascript', JS_LANGUAGE), + '.ts': ('typescript', JS_LANGUAGE), + '.tsx': ('typescript', JS_LANGUAGE), +} + +SKIP_DIRS: Set[str] = { + 'node_modules', '.git', '__pycache__', 'vendor', 'dist', + 'build', '.next', 'coverage', '.pytest_cache', 'venv', + 'env', '.venv', '.env', 'site-packages', '.tox', + '__tests__', 'tests', 'test', 'spec', 'specs', + 'fixtures', 'mocks', 'stubs', +} + +TEST_PATTERNS = ( + '/test.', '.test.', '.spec.', '_test.', + '/tests/', '/test/', '/__tests__/', + 'conftest.py', 'pytest_', 'unittest_', +) + + +class TreeSitterExtractor: + """Extract functions from source code using tree-sitter.""" + + MAX_CODE_LENGTH = 3000 + MAX_FUNCTIONS = 5000 + + def __init__(self): + self.parsers: Dict[str, Parser] = {} + for ext, (lang_name, lang) in LANGUAGE_MAP.items(): + if lang_name not in self.parsers: + self.parsers[lang_name] = Parser(lang) + + def extract_from_repo( + self, + repo_path: Path, + include_paths: Optional[List[str]] = None, + exclude_paths: Optional[List[str]] = None, + max_functions: Optional[int] = None + ) -> List[ExtractedFunction]: + """Extract all functions from a repository.""" + max_funcs = max_functions or self.MAX_FUNCTIONS + functions: List[ExtractedFunction] = [] + + for ext, (lang_name, _) in LANGUAGE_MAP.items(): + for file_path in repo_path.rglob(f"*{ext}"): + rel_path = str(file_path.relative_to(repo_path)) + + if self._should_skip(rel_path, include_paths, exclude_paths): + continue + + try: + code = file_path.read_text(encoding='utf-8', errors='ignore') + functions.extend(self.extract_from_code(code, lang_name, rel_path)) + + if len(functions) >= max_funcs: + return functions[:max_funcs] + except Exception: + continue + + logger.info("Extraction complete", functions=len(functions)) + return functions + + def extract_from_code(self, code: str, language: str, file_path: str) -> List[ExtractedFunction]: + """Extract functions from source code string.""" + parser = self.parsers.get(language) + if not parser: + return [] + + try: + tree = parser.parse(bytes(code, 'utf-8')) + except Exception: + return [] + + if language == 'python': + return self._extract_python(tree, code, file_path) + elif language in ('javascript', 'typescript'): + return self._extract_js_ts(tree, code, file_path, language) + return [] + + def _should_skip(self, rel_path: str, include_paths, exclude_paths) -> bool: + path_parts = rel_path.split(os.sep) + if any(skip in path_parts for skip in SKIP_DIRS): + return True + if any(p in rel_path.lower() for p in TEST_PATTERNS): + return True + if include_paths and not any(rel_path.startswith(p) for p in include_paths): + return True + if exclude_paths and any(rel_path.startswith(p) for p in exclude_paths): + return True + return False + + def _extract_python(self, tree, code: str, file_path: str) -> List[ExtractedFunction]: + functions: List[ExtractedFunction] = [] + code_bytes = bytes(code, 'utf-8') + + def get_text(node) -> str: + return code_bytes[node.start_byte:node.end_byte].decode('utf-8') + + def get_docstring(node) -> Optional[str]: + body = node.child_by_field_name('body') + if body and body.child_count > 0: + first = body.children[0] + if first.type == 'expression_statement' and first.child_count > 0: + expr = first.children[0] + if expr.type == 'string': + doc = get_text(expr) + if doc.startswith('"""') or doc.startswith("'''"): + return doc[3:-3].strip() + return None + + def visit(node, class_name=None): + if node.type == 'function_definition': + name_node = node.child_by_field_name('name') + if not name_node: + return + + name = get_text(name_node) + + # skip private, keep dunders + if name.startswith('_') and not name.startswith('__'): + for child in node.children: + visit(child, class_name) + return + + params_node = node.child_by_field_name('parameters') + params = get_text(params_node) if params_node else '()' + + is_async = any(c.type == 'async' for c in node.children if hasattr(c, 'type')) + + return_type = node.child_by_field_name('return_type') + return_hint = f" -> {get_text(return_type)}" if return_type else "" + + signature = f"{'async ' if is_async else ''}def {name}{params}{return_hint}:" + qualified_name = f"{class_name}.{name}" if class_name else name + + functions.append(ExtractedFunction( + name=name, + qualified_name=qualified_name, + file_path=file_path, + code=get_text(node)[:self.MAX_CODE_LENGTH], + signature=signature, + language='python', + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + class_name=class_name, + docstring=get_docstring(node), + is_async=is_async, + is_method=class_name is not None, + )) + + elif node.type == 'class_definition': + cname_node = node.child_by_field_name('name') + if cname_node: + for child in node.children: + visit(child, get_text(cname_node)) + return + + for child in node.children: + visit(child, class_name) + + visit(tree.root_node) + return functions + + def _extract_js_ts(self, tree, code: str, file_path: str, language: str) -> List[ExtractedFunction]: + functions: List[ExtractedFunction] = [] + code_bytes = bytes(code, 'utf-8') + + def get_text(node) -> str: + return code_bytes[node.start_byte:node.end_byte].decode('utf-8') + + def get_name(node) -> Optional[str]: + name_node = node.child_by_field_name('name') + return get_text(name_node) if name_node else None + + def visit(node, class_name=None): + # arrow functions in variable declarations + if node.type in ('lexical_declaration', 'variable_declaration'): + for child in node.children: + if child.type == 'variable_declarator': + name_node = child.child_by_field_name('name') + value_node = child.child_by_field_name('value') + + if name_node and value_node and value_node.type in ('arrow_function', 'function_expression'): + name = get_text(name_node) + if name.startswith('_'): + continue + + func_code = get_text(node) + params_node = value_node.child_by_field_name('parameters') + params = get_text(params_node) if params_node else '()' + + is_async = 'async' in func_code[:50] + is_exported = func_code.strip().startswith('export') + + functions.append(ExtractedFunction( + name=name, + qualified_name=f"{class_name}.{name}" if class_name else name, + file_path=file_path, + code=func_code[:self.MAX_CODE_LENGTH], + signature=f"{'export ' if is_exported else ''}const {name} = {'async ' if is_async else ''}{params} =>", + language=language, + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + class_name=class_name, + is_async=is_async, + is_method=class_name is not None, + )) + return + + if node.type == 'function_declaration': + name = get_name(node) + if name: + func_code = get_text(node) + params_node = node.child_by_field_name('parameters') + params = get_text(params_node) if params_node else '()' + + is_async = any(c.type == 'async' for c in node.children) + is_exported = node.parent and node.parent.type == 'export_statement' + + functions.append(ExtractedFunction( + name=name, + qualified_name=f"{class_name}.{name}" if class_name else name, + file_path=file_path, + code=func_code[:self.MAX_CODE_LENGTH], + signature=f"{'export ' if is_exported else ''}{'async ' if is_async else ''}function {name}{params}", + language=language, + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + class_name=class_name, + is_async=is_async, + is_method=class_name is not None, + )) + + elif node.type == 'method_definition': + name = get_name(node) + if name and not name.startswith('_'): + func_code = get_text(node) + params_node = node.child_by_field_name('parameters') + params = get_text(params_node) if params_node else '()' + is_async = any(c.type == 'async' for c in node.children) + + functions.append(ExtractedFunction( + name=name, + qualified_name=f"{class_name}.{name}" if class_name else name, + file_path=file_path, + code=func_code[:self.MAX_CODE_LENGTH], + signature=f"{'async ' if is_async else ''}{name}{params}", + language=language, + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + class_name=class_name, + is_async=is_async, + is_method=True, + )) + + elif node.type == 'class_declaration': + cname_node = node.child_by_field_name('name') + if cname_node: + for child in node.children: + visit(child, get_text(cname_node)) + return + + for child in node.children: + visit(child, class_name) + + visit(tree.root_node) + return functions diff --git a/backend/services/search_v2/types.py b/backend/services/search_v2/types.py new file mode 100644 index 0000000..6eef05f --- /dev/null +++ b/backend/services/search_v2/types.py @@ -0,0 +1,88 @@ +"""Data models for Search V2.""" +from dataclasses import dataclass, field +from typing import Optional, List +from enum import Enum + + +class Language(str, Enum): + PYTHON = "python" + JAVASCRIPT = "javascript" + TYPESCRIPT = "typescript" + GO = "go" # not yet supported + JAVA = "java" # not yet supported + + +@dataclass +class ExtractedFunction: + """A function or method extracted from source code.""" + name: str + qualified_name: str # e.g., "Session.get" for methods + file_path: str + code: str + signature: str + language: str + start_line: int + end_line: int + class_name: Optional[str] = None + docstring: Optional[str] = None + is_async: bool = False + is_method: bool = False + decorators: List[str] = field(default_factory=list) + + @property + def display_name(self) -> str: + return self.qualified_name if self.class_name else self.name + + @property + def id_string(self) -> str: + return f"{self.file_path}:{self.qualified_name}:{self.start_line}" + + def to_dict(self) -> dict: + return { + "name": self.name, + "qualified_name": self.qualified_name, + "file_path": self.file_path, + "code": self.code, + "signature": self.signature, + "language": self.language, + "start_line": self.start_line, + "end_line": self.end_line, + "class_name": self.class_name, + "docstring": self.docstring, + "is_async": self.is_async, + "is_method": self.is_method, + "decorators": self.decorators, + } + + +@dataclass +class SearchResult: + """Search result with score and optional summary.""" + name: str + qualified_name: str + file_path: str + code: str + signature: str + language: str + score: float + start_line: int + end_line: int + summary: Optional[str] = None + class_name: Optional[str] = None + match_reason: Optional[str] = None + + def to_dict(self) -> dict: + return { + "name": self.name, + "qualified_name": self.qualified_name, + "file_path": self.file_path, + "code": self.code, + "signature": self.signature, + "language": self.language, + "score": self.score, + "line_start": self.start_line, + "line_end": self.end_line, + "summary": self.summary, + "class_name": self.class_name, + "match_reason": self.match_reason, + } diff --git a/backend/tests/test_search_v2_chunking.py b/backend/tests/test_search_v2_chunking.py new file mode 100644 index 0000000..30caf14 --- /dev/null +++ b/backend/tests/test_search_v2_chunking.py @@ -0,0 +1,248 @@ +"""Tests for Search V2 function-level chunking.""" +import pytest +from services.search_v2 import ( + TreeSitterExtractor, + FunctionFilter, + ExtractedFunction, + filter_functions, +) + + +class TestTreeSitterExtractor: + + @pytest.fixture + def extractor(self): + return TreeSitterExtractor() + + def test_python_function(self, extractor): + code = ''' +def hello_world(name: str) -> str: + """Say hello to someone.""" + return f"Hello, {name}!" +''' + funcs = extractor.extract_from_code(code, 'python', 'test.py') + + assert len(funcs) == 1 + assert funcs[0].name == "hello_world" + assert funcs[0].docstring == "Say hello to someone." + + def test_python_class_method(self, extractor): + code = ''' +class UserService: + def get_user(self, user_id: int) -> dict: + """Fetch a user by ID.""" + return {"id": user_id} + + async def create_user(self, data: dict) -> dict: + """Create a new user.""" + return data +''' + funcs = extractor.extract_from_code(code, 'python', 'services/user.py') + + assert len(funcs) == 2 + + get_user = next(f for f in funcs if f.name == "get_user") + assert get_user.qualified_name == "UserService.get_user" + assert get_user.class_name == "UserService" + assert get_user.is_method is True + + create_user = next(f for f in funcs if f.name == "create_user") + assert create_user.is_async is True + + def test_skip_private(self, extractor): + code = ''' +def public_func(): + pass + +def _private_func(): + pass + +def __dunder_func__(): + pass +''' + funcs = extractor.extract_from_code(code, 'python', 'test.py') + names = [f.name for f in funcs] + + assert "public_func" in names + assert "_private_func" not in names + assert "__dunder_func__" in names + + def test_javascript_functions(self, extractor): + code = ''' +function fetchData(url) { + return fetch(url); +} + +const processData = async (data) => { + return data.map(x => x * 2); +}; +''' + funcs = extractor.extract_from_code(code, 'javascript', 'utils.js') + names = [f.name for f in funcs] + + assert "fetchData" in names + assert "processData" in names + + def test_typescript_class(self, extractor): + code = ''' +class ApiClient { + async get(endpoint: string): Promise { + return fetch(endpoint); + } + + post(endpoint: string, data: object): Promise { + return fetch(endpoint, { method: 'POST' }); + } +} +''' + funcs = extractor.extract_from_code(code, 'typescript', 'api.ts') + + assert len(funcs) == 2 + get_method = next(f for f in funcs if f.name == "get") + assert get_method.qualified_name == "ApiClient.get" + assert get_method.is_async is True + + +class TestFunctionFilter: + + @pytest.fixture + def fltr(self): + return FunctionFilter() + + def _make_func(self, name, file_path="src/module.py"): + return ExtractedFunction( + name=name, + qualified_name=name, + file_path=file_path, + code=f"def {name}(): pass", + signature=f"def {name}():", + language="python", + start_line=1, + end_line=2, + ) + + def test_filter_test_functions(self, fltr): + funcs = [ + self._make_func("test_something"), + self._make_func("mock_database"), + self._make_func("real_function"), + ] + + filtered = fltr.filter_functions(funcs) + names = [f.name for f in filtered] + + assert "real_function" in names + assert "test_something" not in names + assert "mock_database" not in names + + def test_filter_test_directories(self, fltr): + funcs = [ + self._make_func("helper", "tests/test_utils.py"), + self._make_func("real_helper", "src/utils.py"), + ] + + filtered = fltr.filter_functions(funcs) + + assert len(filtered) == 1 + assert filtered[0].name == "real_helper" + + def test_keep_public_api(self, fltr): + funcs = [ + self._make_func("read_csv"), + self._make_func("validate"), + self._make_func("authenticate"), + ] + + filtered = fltr.filter_functions(funcs) + assert len(filtered) == 3 + + def test_filter_long_names(self, fltr): + long_name = "this_is_an_extremely_long_auto_generated_function_name_that_should_be_filtered" + funcs = [ + self._make_func(long_name), + self._make_func("short_name"), + ] + + filtered = fltr.filter_functions(funcs) + names = [f.name for f in filtered] + + assert "short_name" in names + assert long_name not in names + + def test_get_stats(self, fltr): + funcs = [ + self._make_func("test_func"), + self._make_func("real_func"), + self._make_func("_private"), + self._make_func("helper", "tests/helpers.py"), + ] + + stats = fltr.get_stats(funcs) + + assert stats["total"] == 4 + assert stats["kept"] == 1 + assert stats["removed"] == 3 + + +class TestExtractedFunction: + + def test_display_name_simple(self): + func = ExtractedFunction( + name="my_func", qualified_name="my_func", file_path="test.py", + code="def my_func(): pass", signature="def my_func():", + language="python", start_line=1, end_line=2, + ) + assert func.display_name == "my_func" + + def test_display_name_method(self): + func = ExtractedFunction( + name="get", qualified_name="Session.get", file_path="session.py", + code="def get(self): pass", signature="def get(self):", + language="python", start_line=1, end_line=2, class_name="Session", + ) + assert func.display_name == "Session.get" + + def test_id_string(self): + func = ExtractedFunction( + name="process", qualified_name="DataHandler.process", + file_path="handlers/data.py", code="def process(): pass", + signature="def process():", language="python", + start_line=42, end_line=50, + ) + assert func.id_string == "handlers/data.py:DataHandler.process:42" + + def test_to_dict(self): + func = ExtractedFunction( + name="fetch", qualified_name="ApiClient.fetch", file_path="api.py", + code="async def fetch(): pass", signature="async def fetch():", + language="python", start_line=10, end_line=15, + class_name="ApiClient", is_async=True, docstring="Fetch data", + ) + + d = func.to_dict() + assert d["name"] == "fetch" + assert d["qualified_name"] == "ApiClient.fetch" + assert d["is_async"] is True + + +class TestConvenienceFunctions: + + def test_filter_functions_convenience(self): + funcs = [ + ExtractedFunction( + name="test_something", qualified_name="test_something", + file_path="test.py", code="def test_something(): pass", + signature="def test_something():", language="python", + start_line=1, end_line=2, + ), + ExtractedFunction( + name="real_function", qualified_name="real_function", + file_path="src/module.py", code="def real_function(): pass", + signature="def real_function():", language="python", + start_line=1, end_line=2, + ), + ] + + filtered = filter_functions(funcs) + assert len(filtered) == 1 + assert filtered[0].name == "real_function"