# Context carried over from the patch this method was extracted from
# ("feat(search-v2): add hybrid search with BM25 fusion and Cohere reranking"):
#   - target file: backend/services/indexer_optimized.py, added right after
#     semantic_search() and before explain_code()
#   - backend/requirements.txt gains: rank-bm25>=0.2.2, cohere>=5.0.0
#   - backend/services/search_v2/__init__.py additionally imports HybridSearcher
#     from services.search_v2.hybrid_searcher and lists it in __all__
# Relies on names already in scope in that module: time, List, Dict, logger,
# metrics, capture_exception.
async def search_v2(
    self,
    query: str,
    repo_id: str,
    top_k: int = 10,
    use_reranking: bool = True,
) -> List[Dict]:
    """Hybrid search with BM25 fusion and Cohere reranking.

    Args:
        query: Free-text search query.
        repo_id: Repository whose indexed vectors are searched.
        top_k: Maximum number of results to return.
        use_reranking: Forwarded to ``HybridSearcher.search``.

    Returns:
        One dict per hit (``to_dict()`` of each result returned by the
        searcher). An empty list is returned when the search raises; the
        exception is captured, logged and counted instead of propagated.
    """
    # Imported lazily, as in the original patch.
    from services.search_v2 import HybridSearcher

    # perf_counter is monotonic, so the latency metric cannot go negative or
    # jump when the wall clock is adjusted.
    start_time = time.perf_counter()
    metrics.increment("search_v2_requests")

    async def embed(text):
        # The batch embedding API takes a list; unwrap the single vector.
        embeddings = await self._create_embeddings_batch([text])
        return embeddings[0]

    try:
        # Hand the awaitable wrapper straight to the constructor. The previous
        # code passed a lambda that called ``.then()`` on a coroutine (which
        # does not exist) and then overwrote ``searcher.embed`` afterwards.
        searcher = HybridSearcher(
            pinecone_index=self.index,
            embedding_fn=embed,
        )

        results = await searcher.search(
            query=query,
            repo_id=repo_id,
            top_k=top_k,
            use_reranking=use_reranking,
        )

        elapsed = time.perf_counter() - start_time
        logger.info(
            "Search V2 complete",
            repo_id=repo_id,
            results=len(results),
            duration_ms=round(elapsed * 1000),
        )
        metrics.timing("search_v2_latency_ms", elapsed * 1000)

        return [r.to_dict() for r in results]

    except Exception as e:
        # Top-level boundary for the request: report and degrade to "no hits".
        capture_exception(e, operation="search_v2", repo_id=repo_id, query=query[:100])
        logger.error("Search V2 failed", error=str(e))
        metrics.increment("search_v2_errors")
        return []
# Module-level names this code relies on (imported at the top of
# backend/services/search_v2/hybrid_searcher.py in the patch): os,
# typing.List/Dict/Optional, dataclasses.dataclass, cohere,
# rank_bm25.BM25Okapi, services.search_v2.types.SearchResult and
# services.observability.logger.


@dataclass
class ScoredResult:
    """Intermediate result with multiple scores."""
    metadata: Dict
    semantic_score: float = 0.0
    bm25_score: float = 0.0
    rerank_score: float = 0.0
    fused_score: float = 0.0
    # Set by HybridSearcher._rerank once rerank_score holds a real reranker
    # score, so a genuine relevance of 0.0 is not mistaken for "not reranked".
    reranked: bool = False


class HybridSearcher:
    """Combines BM25 keyword search with semantic search and reranking."""

    # Metadata fields concatenated into the text that BM25 scores.
    _BM25_FIELDS = ("name", "qualified_name", "signature", "docstring", "summary")

    def __init__(
        self,
        pinecone_index,
        embedding_fn,
        cohere_api_key: Optional[str] = None,
        rerank_model: str = "rerank-v3.5",
    ):
        """Store collaborators and build the Cohere client when a key exists.

        Args:
            pinecone_index: Object exposing a synchronous ``query(...)`` whose
                result has ``.matches`` with ``.metadata`` and ``.score``.
            embedding_fn: Awaitable callable mapping a query string to its
                embedding vector.
            cohere_api_key: Explicit key; falls back to ``COHERE_API_KEY``.
            rerank_model: Model name passed to ``cohere.rerank``.
        """
        self.index = pinecone_index
        self.embed = embedding_fn
        self.rerank_model = rerank_model

        api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
        self.cohere = cohere.Client(api_key) if api_key else None

        if self.cohere is None:
            logger.warning("Cohere API key not set, reranking disabled")

    async def search(
        self,
        query: str,
        repo_id: str,
        top_k: int = 10,
        semantic_weight: float = 0.7,
        bm25_weight: float = 0.3,
        use_reranking: bool = True,
        num_candidates: int = 50,
    ) -> List["SearchResult"]:
        """
        Hybrid search with optional reranking.

        1. Fetch candidates via semantic search
        2. Apply BM25 scoring on candidates
        3. Fuse scores using RRF
        4. Rerank top results with Cohere

        Args:
            query: Free-text query.
            repo_id: Restricts the vector query to this repository.
            top_k: Maximum number of results returned.
            semantic_weight: Weight of the semantic rank in the fusion.
            bm25_weight: Weight of the BM25 rank in the fusion.
            use_reranking: Rerank the shortlist when a Cohere client exists.
            num_candidates: Minimum size of the semantic candidate pool
                (defaults to the previously hard-coded 50).

        Returns:
            Up to ``top_k`` results, best first; empty when nothing matches
            or ``top_k`` is not positive.
        """
        if top_k <= 0:
            return []

        # The shortlist below is 2 * top_k, so the pool must be at least that
        # large; a fixed pool of 50 silently capped large top_k values.
        pool_size = max(num_candidates, top_k * 2)
        candidates = await self._semantic_search(query, repo_id, top_k=pool_size)
        if not candidates:
            return []

        candidates = self._apply_bm25(query, candidates)
        candidates = self._rrf_fusion(candidates, semantic_weight, bm25_weight)
        candidates.sort(key=lambda c: c.fused_score, reverse=True)

        # Only the head of the fused ranking is sent to the reranker.
        shortlist = candidates[:top_k * 2]
        if use_reranking and self.cohere:
            shortlist = await self._rerank(query, shortlist)

        return [self._to_search_result(c) for c in shortlist[:top_k]]

    @staticmethod
    async def _run_blocking(call):
        """Run a blocking zero-argument callable in the loop's default executor."""
        import asyncio

        return await asyncio.get_running_loop().run_in_executor(None, call)

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Lower-case ``text`` and split it into alphanumeric tokens.

        Splitting on punctuation and underscores lets a query word such as
        ``users`` match ``fetch_users(self,``; plain whitespace splitting kept
        the punctuation attached and never matched identifiers.
        """
        import re

        return re.findall(r"[^\W_]+", text.lower())

    async def _semantic_search(self, query: str, repo_id: str, top_k: int) -> List[ScoredResult]:
        """Fetch candidates from Pinecone."""
        query_embedding = await self.embed(query)

        # index.query is synchronous; keep it off the event loop.
        results = await self._run_blocking(
            lambda: self.index.query(
                vector=query_embedding,
                top_k=top_k,
                include_metadata=True,
                filter={"repo_id": {"$eq": repo_id}},
            )
        )

        return [
            ScoredResult(
                # A match without metadata must not break the later .get() calls.
                metadata=match.metadata or {},
                semantic_score=match.score,
            )
            for match in results.matches
        ]

    def _apply_bm25(self, query: str, candidates: List[ScoredResult]) -> List[ScoredResult]:
        """Score candidates with BM25, normalised so the best match is 1.0.

        Candidates are updated in place and the same list is returned. When
        the query or every candidate has no tokens, all BM25 scores are 0.0.
        """
        if not candidates:
            return candidates

        # `or ""` covers keys that exist with a None value.
        corpus = [
            self._tokenize(
                " ".join(str(c.metadata.get(field) or "") for field in self._BM25_FIELDS)
            )
            for c in candidates
        ]
        query_tokens = self._tokenize(query)

        # BM25Okapi divides by the vocabulary size / average document length,
        # so an all-empty corpus must not reach it.
        if not query_tokens or not any(corpus):
            for c in candidates:
                c.bm25_score = 0.0
            return candidates

        scores = BM25Okapi(corpus).get_scores(query_tokens)

        best = float(max(scores))
        divisor = best if best > 0 else 1.0
        for c, raw in zip(candidates, scores):
            c.bm25_score = float(raw) / divisor

        return candidates

    def _rrf_fusion(
        self,
        candidates: List[ScoredResult],
        semantic_weight: float,
        bm25_weight: float,
        k: int = 60
    ) -> List[ScoredResult]:
        """Reciprocal Rank Fusion.

        Each candidate receives ``weight / (k + rank)`` (rank starting at 1)
        from its position in the semantic ordering and in the BM25 ordering.
        ``fused_score`` is overwritten in place; the input list is returned
        in its original order.
        """
        by_semantic = sorted(candidates, key=lambda c: c.semantic_score, reverse=True)
        for rank, c in enumerate(by_semantic, start=1):
            c.fused_score = semantic_weight / (k + rank)

        # sorted() is stable, so BM25 ties keep their incoming order.
        by_bm25 = sorted(candidates, key=lambda c: c.bm25_score, reverse=True)
        for rank, c in enumerate(by_bm25, start=1):
            c.fused_score += bm25_weight / (k + rank)

        return candidates

    async def _rerank(self, query: str, candidates: List[ScoredResult]) -> List[ScoredResult]:
        """Rerank with Cohere.

        Returns the candidates in reranker order with ``rerank_score`` set and
        ``reranked`` flagged. On any failure a warning is logged and the input
        list is returned untouched.
        """
        if not candidates:
            return candidates

        docs = []
        for c in candidates:
            qualified = c.metadata.get("qualified_name", "")
            # Prefer the summary; fall back to the signature when it is empty.
            detail = c.metadata.get("summary") or c.metadata.get("signature") or ""
            docs.append(f"{qualified}: {detail}")

        try:
            # The Cohere client call is synchronous; keep it off the event loop.
            response = await self._run_blocking(
                lambda: self.cohere.rerank(
                    query=query,
                    documents=docs,
                    model=self.rerank_model,
                    top_n=len(candidates),
                )
            )
            # Resolve the whole response before mutating anything, so a bad
            # entry cannot leave some candidates reranked and others not.
            ranked = [(candidates[r.index], r.relevance_score) for r in response.results]
        except Exception as e:
            logger.warning("Reranking failed", error=str(e))
            return candidates

        for c, relevance in ranked:
            c.rerank_score = relevance
            c.reranked = True

        return [c for c, _ in ranked]

    def _to_search_result(self, scored: ScoredResult) -> "SearchResult":
        """Convert to SearchResult.

        The reported score is the rerank score when the candidate was
        reranked (even if that score is 0.0) or carries a non-zero rerank
        score; otherwise it is the fused score.
        """
        m = scored.metadata
        use_rerank = scored.reranked or bool(scored.rerank_score)
        return SearchResult(
            name=m.get("name", ""),
            qualified_name=m.get("qualified_name", ""),
            file_path=m.get("file_path", ""),
            code=m.get("code", ""),
            signature=m.get("signature", ""),
            language=m.get("language", ""),
            score=scored.rerank_score if use_rerank else scored.fused_score,
            start_line=m.get("start_line", 0),
            end_line=m.get("end_line", 0),
            summary=m.get("summary"),
            class_name=m.get("class_name"),
        )
"""Tests for hybrid search and reranking."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from dataclasses import dataclass

from services.search_v2.hybrid_searcher import HybridSearcher, ScoredResult


@dataclass
class MockMatch:
    """Stand-in for one match of the vector index query result."""
    metadata: dict
    score: float


@dataclass
class MockQueryResult:
    """Stand-in for the object returned by ``index.query``."""
    matches: list


def _function_metadata(name, params, body, start_line, summary):
    """Build the metadata dict of a ``UserService`` method used by the fixtures."""
    return {
        "name": name,
        "qualified_name": f"UserService.{name}",
        "signature": f"def {name}({params}):",
        "code": body,
        "file_path": "services/user.py",
        "language": "python",
        "start_line": start_line,
        "end_line": start_line + 5,
        "summary": summary,
    }


class TestHybridSearcher:
    """Behavioural tests for HybridSearcher using a mocked index and reranker."""

    @pytest.fixture
    def mock_embed(self):
        """Async embedding stub returning a constant 1024-dim vector."""
        async def embed(query):
            return [0.1] * 1024

        return embed

    @pytest.fixture
    def mock_index(self):
        """Index whose query returns three matches in descending semantic score."""
        index = MagicMock()
        index.query = MagicMock(return_value=MockQueryResult(matches=[
            MockMatch(metadata=_function_metadata(
                "fetch_users",
                "self, limit: int",
                "def fetch_users(self, limit): return self.db.query()",
                10,
                "Fetches users from database with limit",
            ), score=0.92),
            MockMatch(metadata=_function_metadata(
                "get_user",
                "self, user_id: int",
                "def get_user(self, user_id): return self.db.get(user_id)",
                20,
                "Gets single user by ID",
            ), score=0.85),
            MockMatch(metadata=_function_metadata(
                "delete_user",
                "self, user_id: int",
                "def delete_user(self, user_id): self.db.delete(user_id)",
                30,
                "Deletes user from database",
            ), score=0.78),
        ]))
        return index

    @pytest.fixture
    def searcher(self, mock_index, mock_embed):
        """Searcher built without an explicit Cohere key (basic tests)."""
        return HybridSearcher(
            pinecone_index=mock_index,
            embedding_fn=mock_embed,
            cohere_api_key=None,  # disable reranking for basic tests
        )

    @pytest.mark.asyncio
    async def test_search_returns_results(self, searcher):
        results = await searcher.search("fetch users", "repo-123", top_k=3, use_reranking=False)

        assert len(results) == 3
        assert results[0].name == "fetch_users"
        assert results[0].qualified_name == "UserService.fetch_users"

    @pytest.mark.asyncio
    async def test_search_applies_bm25(self, searcher):
        # delete_user has the LOWEST semantic score, so it can only come first
        # if the BM25 keyword match is really contributing to the ranking.
        results = await searcher.search(
            "deletes",
            "repo-123",
            top_k=3,
            semantic_weight=0.0,
            bm25_weight=1.0,
            use_reranking=False,
        )

        assert results[0].name == "delete_user"

    @pytest.mark.asyncio
    async def test_search_empty_results(self, mock_index, mock_embed):
        mock_index.query.return_value = MockQueryResult(matches=[])

        searcher = HybridSearcher(mock_index, mock_embed)

        results = await searcher.search("nonexistent", "repo-123", use_reranking=False)
        assert results == []

    @pytest.mark.asyncio
    async def test_rrf_fusion(self, searcher):
        candidates = [
            ScoredResult(metadata={"name": "a"}, semantic_score=0.9, bm25_score=0.3),
            ScoredResult(metadata={"name": "b"}, semantic_score=0.7, bm25_score=0.9),
            ScoredResult(metadata={"name": "c"}, semantic_score=0.5, bm25_score=0.5),
        ]

        fused = searcher._rrf_fusion(candidates, semantic_weight=0.6, bm25_weight=0.4)
        by_name = {c.metadata["name"]: c.fused_score for c in fused}

        # weight / (k + rank) with k=60 and 1-based ranks per ordering
        assert by_name["a"] == pytest.approx(0.6 / 61 + 0.4 / 63)
        assert by_name["b"] == pytest.approx(0.6 / 62 + 0.4 / 61)
        assert by_name["c"] == pytest.approx(0.6 / 63 + 0.4 / 62)
        # "b" is second semantically but first on BM25, which puts it on top
        assert by_name["b"] > by_name["a"] > by_name["c"]

    @pytest.mark.asyncio
    async def test_search_with_reranking(self, mock_index, mock_embed):
        mock_cohere = MagicMock()
        mock_cohere.rerank.return_value = MagicMock(results=[
            MagicMock(index=1, relevance_score=0.95),  # get_user now top
            MagicMock(index=0, relevance_score=0.90),
            MagicMock(index=2, relevance_score=0.70),
        ])

        with patch('services.search_v2.hybrid_searcher.cohere.Client', return_value=mock_cohere):
            searcher = HybridSearcher(mock_index, mock_embed, cohere_api_key="test-key")

        results = await searcher.search("get user by id", "repo-123", top_k=3, use_reranking=True)

        # reranking should reorder results
        assert results[0].name == "get_user"
        assert results[0].score == 0.95

    @pytest.mark.asyncio
    async def test_search_without_cohere_client(self, searcher):
        # no reranker is wired in: requesting reranking must be a no-op
        searcher.cohere = None

        results = await searcher.search("test", "repo-123", top_k=2, use_reranking=True)
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_rerank_handles_error(self, mock_index, mock_embed):
        # a reranker that raises must degrade to the fused ranking
        failing_cohere = MagicMock()
        failing_cohere.rerank.side_effect = RuntimeError("cohere unavailable")

        with patch('services.search_v2.hybrid_searcher.cohere.Client', return_value=failing_cohere):
            searcher = HybridSearcher(mock_index, mock_embed, cohere_api_key="test-key")

        results = await searcher.search("fetch users", "repo-123", top_k=2, use_reranking=True)

        failing_cohere.rerank.assert_called_once()
        assert len(results) == 2
        assert results[0].name == "fetch_users"


class TestScoredResult:
    """Checks the ScoredResult dataclass defaults and explicit values."""

    def test_default_scores(self):
        result = ScoredResult(metadata={"name": "test"})
        assert result.semantic_score == 0.0
        assert result.bm25_score == 0.0
        assert result.rerank_score == 0.0
        assert result.fused_score == 0.0

    def test_with_scores(self):
        result = ScoredResult(
            metadata={"name": "test"},
            semantic_score=0.9,
            bm25_score=0.5,
        )
        assert result.semantic_score == 0.9
        assert result.bm25_score == 0.5