Skip to content

Commit 97d3913

Browse files
authored
Merge pull request #140 from DevanshuNEU/feat/hybrid-search-reranking
feat(search-v2): add hybrid search with BM25 fusion and Cohere reranking
2 parents 961e8c0 + 421a625 commit 97d3913

5 files changed

Lines changed: 406 additions & 1 deletion

File tree

backend/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,7 @@ pytest-cov>=6.0.0
3939

4040
# Observability
4141
sentry-sdk[fastapi]>=2.0.0
42+
43+
# Search V2 - Hybrid search
44+
rank-bm25>=0.2.2
45+
cohere>=5.0.0

backend/services/indexer_optimized.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,52 @@ async def semantic_search(
523523
logger.error("Search failed", repo_id=repo_id, error=str(e))
524524
metrics.increment("search_errors")
525525
return []
526-
526+
527+
async def search_v2(
528+
self,
529+
query: str,
530+
repo_id: str,
531+
top_k: int = 10,
532+
use_reranking: bool = True,
533+
) -> List[Dict]:
534+
"""Hybrid search with BM25 fusion and Cohere reranking."""
535+
from services.search_v2 import HybridSearcher
536+
537+
start_time = time.time()
538+
metrics.increment("search_v2_requests")
539+
540+
try:
541+
searcher = HybridSearcher(
542+
pinecone_index=self.index,
543+
embedding_fn=lambda q: self._create_embeddings_batch([q]).then(lambda x: x[0]),
544+
)
545+
546+
# wrapper for async embed
547+
async def embed(q):
548+
embs = await self._create_embeddings_batch([q])
549+
return embs[0]
550+
551+
searcher.embed = embed
552+
553+
results = await searcher.search(
554+
query=query,
555+
repo_id=repo_id,
556+
top_k=top_k,
557+
use_reranking=use_reranking,
558+
)
559+
560+
elapsed = time.time() - start_time
561+
logger.info("Search V2 complete", repo_id=repo_id, results=len(results), duration_ms=round(elapsed*1000))
562+
metrics.timing("search_v2_latency_ms", elapsed * 1000)
563+
564+
return [r.to_dict() for r in results]
565+
566+
except Exception as e:
567+
capture_exception(e, operation="search_v2", repo_id=repo_id, query=query[:100])
568+
logger.error("Search V2 failed", error=str(e))
569+
metrics.increment("search_v2_errors")
570+
return []
571+
527572
async def explain_code(
528573
self,
529574
repo_id: str,

backend/services/search_v2/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from services.search_v2.tree_sitter_extractor import TreeSitterExtractor
44
from services.search_v2.function_filter import FunctionFilter, filter_functions
55
from services.search_v2.summary_generator import SummaryGenerator, generate_summaries
6+
from services.search_v2.hybrid_searcher import HybridSearcher
67

78
__all__ = [
89
"ExtractedFunction",
@@ -13,4 +14,5 @@
1314
"filter_functions",
1415
"SummaryGenerator",
1516
"generate_summaries",
17+
"HybridSearcher",
1618
]
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""Hybrid search with BM25 + semantic fusion and Cohere reranking."""
2+
import os
3+
from typing import List, Dict, Optional
4+
from dataclasses import dataclass
5+
6+
import cohere
7+
from rank_bm25 import BM25Okapi
8+
9+
from services.search_v2.types import SearchResult
10+
from services.observability import logger
11+
12+
13+
@dataclass
14+
class ScoredResult:
15+
"""Intermediate result with multiple scores."""
16+
metadata: Dict
17+
semantic_score: float = 0.0
18+
bm25_score: float = 0.0
19+
rerank_score: float = 0.0
20+
fused_score: float = 0.0
21+
22+
23+
class HybridSearcher:
24+
"""Combines BM25 keyword search with semantic search and reranking."""
25+
26+
def __init__(
27+
self,
28+
pinecone_index,
29+
embedding_fn,
30+
cohere_api_key: Optional[str] = None,
31+
rerank_model: str = "rerank-v3.5",
32+
):
33+
self.index = pinecone_index
34+
self.embed = embedding_fn
35+
self.rerank_model = rerank_model
36+
37+
api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
38+
self.cohere = cohere.Client(api_key) if api_key else None
39+
40+
if not self.cohere:
41+
logger.warning("Cohere API key not set, reranking disabled")
42+
43+
async def search(
44+
self,
45+
query: str,
46+
repo_id: str,
47+
top_k: int = 10,
48+
semantic_weight: float = 0.7,
49+
bm25_weight: float = 0.3,
50+
use_reranking: bool = True,
51+
) -> List[SearchResult]:
52+
"""
53+
Hybrid search with optional reranking.
54+
55+
1. Fetch candidates via semantic search (top 50)
56+
2. Apply BM25 scoring on candidates
57+
3. Fuse scores using RRF
58+
4. Rerank top results with Cohere
59+
"""
60+
# get semantic candidates
61+
candidates = await self._semantic_search(query, repo_id, top_k=50)
62+
if not candidates:
63+
return []
64+
65+
# apply bm25 on candidates
66+
candidates = self._apply_bm25(query, candidates)
67+
68+
# fuse scores
69+
candidates = self._rrf_fusion(candidates, semantic_weight, bm25_weight)
70+
71+
# sort by fused score
72+
candidates.sort(key=lambda x: x.fused_score, reverse=True)
73+
74+
# rerank top results
75+
top_candidates = candidates[:top_k * 2]
76+
if use_reranking and self.cohere:
77+
top_candidates = await self._rerank(query, top_candidates)
78+
79+
# convert to SearchResult
80+
return [self._to_search_result(c) for c in top_candidates[:top_k]]
81+
82+
async def _semantic_search(self, query: str, repo_id: str, top_k: int) -> List[ScoredResult]:
83+
"""Fetch candidates from Pinecone."""
84+
query_embedding = await self.embed(query)
85+
86+
results = self.index.query(
87+
vector=query_embedding,
88+
top_k=top_k,
89+
include_metadata=True,
90+
filter={"repo_id": {"$eq": repo_id}}
91+
)
92+
93+
return [
94+
ScoredResult(
95+
metadata=match.metadata,
96+
semantic_score=match.score,
97+
)
98+
for match in results.matches
99+
]
100+
101+
def _apply_bm25(self, query: str, candidates: List[ScoredResult]) -> List[ScoredResult]:
102+
"""Score candidates with BM25."""
103+
if not candidates:
104+
return candidates
105+
106+
# build corpus from candidates
107+
corpus = []
108+
for c in candidates:
109+
text = f"{c.metadata.get('name', '')} {c.metadata.get('qualified_name', '')} "
110+
text += f"{c.metadata.get('signature', '')} {c.metadata.get('docstring', '')} "
111+
text += c.metadata.get('summary', '')
112+
corpus.append(text.lower().split())
113+
114+
bm25 = BM25Okapi(corpus)
115+
query_tokens = query.lower().split()
116+
scores = bm25.get_scores(query_tokens)
117+
118+
# normalize scores
119+
max_score = max(scores) if max(scores) > 0 else 1
120+
for i, c in enumerate(candidates):
121+
c.bm25_score = scores[i] / max_score
122+
123+
return candidates
124+
125+
def _rrf_fusion(
126+
self,
127+
candidates: List[ScoredResult],
128+
semantic_weight: float,
129+
bm25_weight: float,
130+
k: int = 60
131+
) -> List[ScoredResult]:
132+
"""Reciprocal Rank Fusion."""
133+
# sort by semantic for ranking
134+
by_semantic = sorted(candidates, key=lambda x: x.semantic_score, reverse=True)
135+
for rank, c in enumerate(by_semantic):
136+
c.fused_score = semantic_weight / (k + rank + 1)
137+
138+
# sort by bm25 for ranking
139+
by_bm25 = sorted(candidates, key=lambda x: x.bm25_score, reverse=True)
140+
for rank, c in enumerate(by_bm25):
141+
c.fused_score += bm25_weight / (k + rank + 1)
142+
143+
return candidates
144+
145+
async def _rerank(self, query: str, candidates: List[ScoredResult]) -> List[ScoredResult]:
146+
"""Rerank with Cohere."""
147+
if not candidates:
148+
return candidates
149+
150+
docs = []
151+
for c in candidates:
152+
doc = f"{c.metadata.get('qualified_name', '')}: {c.metadata.get('summary', '')}"
153+
if not c.metadata.get('summary'):
154+
doc = f"{c.metadata.get('qualified_name', '')}: {c.metadata.get('signature', '')}"
155+
docs.append(doc)
156+
157+
try:
158+
response = self.cohere.rerank(
159+
query=query,
160+
documents=docs,
161+
model=self.rerank_model,
162+
top_n=len(candidates),
163+
)
164+
165+
reranked = []
166+
for r in response.results:
167+
c = candidates[r.index]
168+
c.rerank_score = r.relevance_score
169+
reranked.append(c)
170+
171+
return reranked
172+
173+
except Exception as e:
174+
logger.warning("Reranking failed", error=str(e))
175+
return candidates
176+
177+
def _to_search_result(self, scored: ScoredResult) -> SearchResult:
178+
"""Convert to SearchResult."""
179+
m = scored.metadata
180+
return SearchResult(
181+
name=m.get("name", ""),
182+
qualified_name=m.get("qualified_name", ""),
183+
file_path=m.get("file_path", ""),
184+
code=m.get("code", ""),
185+
signature=m.get("signature", ""),
186+
language=m.get("language", ""),
187+
score=scored.rerank_score if scored.rerank_score else scored.fused_score,
188+
start_line=m.get("start_line", 0),
189+
end_line=m.get("end_line", 0),
190+
summary=m.get("summary"),
191+
class_name=m.get("class_name"),
192+
)

0 commit comments

Comments
 (0)