Skip to content

Commit 68db72d

Browse files
author
Mahika Wason
committed
Subquery generator AbstractOperator init
1 parent 80f254f commit 68db72d

2 files changed

Lines changed: 333 additions & 0 deletions

File tree

nemo_retriever/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ dependencies = [
8383
svg = [
8484
"cairosvg>=2.7.0",
8585
]
86+
llm = [
87+
"openai>=1.0",
88+
]
8689
dev = [
8790
"build>=1.2.2",
8891
"pytest>=8.0.2",
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
2+
# All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Operator that expands a query DataFrame into sub-queries via an LLM."""
6+
7+
from __future__ import annotations
8+
9+
import json
10+
import os
11+
from typing import Any, List, Literal, Optional
12+
13+
import pandas as pd
14+
15+
from nemo_retriever.graph.abstract_operator import AbstractOperator
16+
from nemo_retriever.graph.cpu_operator import CPUOperator
17+
18+
# openai is imported lazily in _ensure_client() so the operator stays
19+
# serialisable for Ray workers without requiring it at import time.
20+
21+
22+
# ---------------------------------------------------------------------------
23+
# Built-in strategy prompts
24+
# ---------------------------------------------------------------------------
25+
26+
# Mapping from the operator's ``strategy`` parameter to the system-prompt
# template for that strategy.  Every template carries a ``{max_subqueries}``
# placeholder that is substituted via ``str.format`` when the prompt is built
# (see ``SubQueryGeneratorOperator._build_system_prompt``).  All templates end
# with the same output contract: a bare JSON array of strings.
_PROMPTS: dict[str, str] = {
    # Split the query into complementary sub-aspects (the default strategy).
    "decompose": """\
You are a query decomposition assistant for a retrieval system.

Given a search query, break it down into up to {max_subqueries} distinct sub-queries that \
together cover all aspects and angles of the original query. Generate as many sub-queries as \
are genuinely useful — do not pad with redundant ones just to hit the maximum. Each sub-query \
should target a specific facet, making it easier for a dense retrieval system to find all \
relevant documents.

Rules:
- Each sub-query must be self-contained and meaningful on its own.
- Sub-queries should be diverse and complementary, not redundant.
- Use clear, precise language suited for dense embedding retrieval.
- Output a JSON array of strings only — no explanation, no markdown fences.""",
    # HyDE: generate hypothetical answer passages whose embeddings are used
    # as retrieval queries.
    "hyde": """\
You are a Hypothetical Document Embedding (HyDE) assistant for a retrieval system.

Given a search query, generate up to {max_subqueries} short hypothetical document passages \
(2–4 sentences each) that would directly answer or address the query. Generate as many as are \
genuinely useful — fewer is fine if the query is simple. These passages will be used as queries \
to a dense retrieval system to find real, similar documents.

Rules:
- Each passage should read like a real document excerpt that answers the query.
- Vary the style and perspective across passages (e.g., academic, technical, narrative).
- Be factually plausible; focus on covering the query intent.
- Output a JSON array of strings only — no explanation, no markdown fences.""",
    # Reformulate the query from diverse angles / levels of specificity to
    # maximise recall.
    "multi_perspective": """\
You are a multi-perspective query expansion assistant for a retrieval system.

Given a search query, generate up to {max_subqueries} reformulations from different angles, \
perspectives, or levels of specificity to maximise recall in a dense retrieval system. Only \
generate reformulations that add genuine coverage — do not pad.

Rules:
- Vary terminology: use synonyms, technical vs. casual language, acronyms vs. full names.
- Vary scope: broad overview queries alongside narrow, specific ones.
- Vary form: declarative statements, questions, and entity-focused queries.
- Each reformulation must have a meaningfully different surface form from the others.
- Output a JSON array of strings only — no explanation, no markdown fences.""",
}
70+
71+
72+
# ---------------------------------------------------------------------------
73+
# Operator
74+
# ---------------------------------------------------------------------------
75+
76+
77+
class SubQueryGeneratorOperator(AbstractOperator, CPUOperator):
    """Expand each query row into sub-query rows using an LLM.

    The operator calls an LLM (via the OpenAI SDK) once per input query and
    explodes the result into one output row per generated sub-query. The LLM
    decides how many sub-queries to generate up to ``max_subqueries``. This
    makes it a natural upstream stage for a retrieval operator: the downstream
    operator can retrieve documents independently for every sub-query row, and
    a subsequent aggregation step (e.g. RRF) can merge the per-sub-query
    ranked lists back into a single ranking per ``query_id``.

    Input DataFrame schema
    ----------------------
    query_id : str — unique identifier for the query
    query_text : str — the search query text
    (any additional columns are passed through unchanged)

    Output DataFrame schema
    -----------------------
    query_id : str — same ``query_id`` as the input row
    query_text : str — original query text (preserved for context)
    subquery_idx : int — 0-based position within the generated sub-query group
    subquery_text : str — the generated sub-query text
    (additional input columns are forwarded to every expanded row)

    Parameters
    ----------
    llm_model : str
        OpenAI model identifier, e.g. ``"gpt-4o"``.
    max_subqueries : int
        Maximum number of sub-queries the LLM may generate per query.
        The LLM will generate fewer if the query does not warrant the maximum.
        Defaults to ``4``.
    strategy : {"decompose", "hyde", "multi_perspective"}
        Built-in sub-query generation strategy.

        ``"decompose"``
            Break the query into complementary sub-aspects (default).
        ``"hyde"``
            Generate hypothetical answer passages (HyDE).
        ``"multi_perspective"``
            Rewrite the query from diverse angles to maximise recall.
    api_key : str, optional
        Literal API key **or** an ``"os.environ/VAR_NAME"`` reference that is
        resolved at call time. Omit to rely on the ``OPENAI_API_KEY``
        environment variable.
    base_url : str, optional
        Custom endpoint URL — useful for NIM deployments or self-hosted models.
    max_tokens : int, optional
        Upper bound on tokens in the LLM response.
    system_prompt_override : str, optional
        Fully custom system prompt. Use ``{max_subqueries}`` as a placeholder.
        When provided, ``strategy`` is ignored.

    Examples
    --------
    Standalone use::

        import pandas as pd
        from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator

        op = SubQueryGeneratorOperator(llm_model="gpt-4o", max_subqueries=5)
        df = pd.DataFrame({
            "query_id": ["q1", "q2"],
            "query_text": ["What causes inflation?", "How do vaccines work?"],
        })
        result = op.run(df)
        # result has up to 10 rows: ≤5 sub-queries × 2 original queries

    Composing into a graph::

        from nemo_retriever.graph import InprocessExecutor
        from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator

        graph = (
            SubQueryGeneratorOperator(llm_model="gpt-4o", max_subqueries=4)
            >> RetrievalOperator(retriever=my_retriever)
            >> RRFAggregatorOperator()
        )
        executor = InprocessExecutor(graph)
        result_df = executor.ingest(query_df)
    """

    def __init__(
        self,
        *,
        llm_model: str,
        max_subqueries: int = 4,
        strategy: Literal["decompose", "hyde", "multi_perspective"] = "decompose",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        max_tokens: Optional[int] = None,
        system_prompt_override: Optional[str] = None,
    ) -> None:
        super().__init__()
        self._llm_model = llm_model
        self._max_subqueries = max_subqueries
        self._strategy = strategy
        self._api_key = api_key
        self._base_url = base_url
        self._max_tokens = max_tokens
        self._system_prompt_override = system_prompt_override

        # OpenAI client is created lazily so the operator stays serialisable for Ray.
        self._client: Any = None

    # ------------------------------------------------------------------
    # AbstractOperator interface
    # ------------------------------------------------------------------

    def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame:
        """Normalise *data* to a DataFrame with ``query_id`` / ``query_text`` columns.

        Accepted input types
        --------------------
        ``pd.DataFrame``
            Must contain at least ``query_id`` and ``query_text`` columns.
        ``list[str]``
            Plain query strings; ``query_id`` values are auto-assigned as
            ``"q0"``, ``"q1"``, …
        ``list[tuple[str, str]]`` or ``list[list[str, str]]``
            ``(query_id, query_text)`` pairs.

        An empty list yields an empty (but well-formed) query frame.
        """
        if isinstance(data, pd.DataFrame):
            missing = {"query_id", "query_text"} - set(data.columns)
            if missing:
                raise ValueError(
                    f"Input DataFrame is missing required column(s): {sorted(missing)}. "
                    "Expected at minimum: 'query_id' and 'query_text'."
                )
            return data.copy()

        if isinstance(data, list):
            # Lists are a documented input type, so an empty list must not
            # fall through to the "unsupported type" error below — return an
            # empty frame with the expected columns instead.
            if not data:
                return pd.DataFrame(columns=["query_id", "query_text"])
            first = data[0]
            if isinstance(first, str):
                return pd.DataFrame(
                    {
                        "query_id": [f"q{i}" for i in range(len(data))],
                        "query_text": list(data),
                    }
                )
            if isinstance(first, (tuple, list)) and len(first) == 2:
                return pd.DataFrame(data, columns=["query_id", "query_text"])

        raise TypeError(
            f"Unsupported input type {type(data).__name__!r}. "
            "Pass a pd.DataFrame with 'query_id' and 'query_text' columns, "
            "a list[str], or a list[tuple[str, str]]."
        )

    def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
        """Generate sub-queries for every row and explode to one row per sub-query."""
        self._ensure_client()
        system_prompt = self._build_system_prompt()

        passthrough_cols = [c for c in data.columns if c not in ("query_id", "query_text")]
        rows: List[dict[str, Any]] = []

        for _, row in data.iterrows():
            subqueries = self._generate_one(row["query_text"], system_prompt)
            for idx, sq in enumerate(subqueries):
                new_row: dict[str, Any] = {
                    "query_id": row["query_id"],
                    "query_text": row["query_text"],
                    "subquery_idx": idx,
                    "subquery_text": sq,
                }
                for col in passthrough_cols:
                    new_row[col] = row[col]
                rows.append(new_row)

        if not rows:
            # Empty input (or no generated sub-queries) still returns the
            # documented output schema so downstream stages see stable columns.
            return pd.DataFrame(columns=["query_id", "query_text", "subquery_idx", "subquery_text"])

        return pd.DataFrame(rows)

    def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
        """No-op: the expanded frame from :meth:`process` is already final."""
        return data

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_client(self) -> None:
        """Lazily create the OpenAI client (once per instance)."""
        if self._client is not None:
            return
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ImportError(
                "SubQueryGeneratorOperator requires 'openai'. "
                "Install it with: pip install 'openai>=1.0'"
            ) from exc

        api_key = self._api_key
        if api_key is not None and api_key.strip().startswith("os.environ/"):
            var = api_key.strip().removeprefix("os.environ/")
            if var not in os.environ:
                # Same exception type as a plain os.environ[var] lookup, but
                # with enough context to diagnose the misconfiguration.
                raise KeyError(
                    f"api_key references environment variable {var!r}, which is not set."
                )
            api_key = os.environ[var]

        self._client = OpenAI(
            api_key=api_key,
            **({"base_url": self._base_url} if self._base_url is not None else {}),
        )

    def _build_system_prompt(self) -> str:
        """Return the system prompt with ``{max_subqueries}`` substituted."""
        template = self._system_prompt_override or _PROMPTS[self._strategy]
        return template.format(max_subqueries=self._max_subqueries)

    def _generate_one(self, query: str, system_prompt: str) -> List[str]:
        """Call the LLM and return a list of sub-query strings for *query*."""
        call_kwargs: dict[str, Any] = dict(
            model=self._llm_model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Query: {query}"},
            ],
        )
        if self._max_tokens is not None:
            call_kwargs["max_tokens"] = self._max_tokens

        response = self._client.chat.completions.create(**call_kwargs)
        # The SDK types message.content as Optional[str]; guard against None
        # so a contentless response degrades to the [query] fallback instead
        # of raising AttributeError on .strip().
        raw = (response.choices[0].message.content or "").strip()
        return _parse_json_list(raw, fallback=query)
301+
302+
303+
# ---------------------------------------------------------------------------
304+
# Module-level helpers (no instance state — easier to test in isolation)
305+
# ---------------------------------------------------------------------------
306+
307+
308+
def _parse_json_list(raw: str, *, fallback: str) -> List[str]:
309+
"""Parse a JSON array from *raw*, stripping markdown fences if present.
310+
311+
Returns *[fallback]* when parsing fails so downstream stages always
312+
receive at least one sub-query row.
313+
"""
314+
text = raw
315+
for fence in ("```json", "```"):
316+
if text.startswith(fence):
317+
text = text[len(fence):]
318+
break
319+
if text.endswith("```"):
320+
text = text[:-3]
321+
text = text.strip()
322+
323+
try:
324+
parsed = json.loads(text)
325+
if isinstance(parsed, list) and parsed and all(isinstance(s, str) for s in parsed):
326+
return parsed
327+
except json.JSONDecodeError:
328+
pass
329+
330+
return [fallback]

0 commit comments

Comments
 (0)