|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. |
| 2 | +# All rights reserved. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +"""Operator that expands a query DataFrame into sub-queries via an LLM.""" |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +import json |
| 10 | +import os |
| 11 | +from typing import Any, List, Literal, Optional |
| 12 | + |
| 13 | +import pandas as pd |
| 14 | + |
| 15 | +from nemo_retriever.graph.abstract_operator import AbstractOperator |
| 16 | +from nemo_retriever.graph.cpu_operator import CPUOperator |
| 17 | + |
| 18 | +# openai is imported lazily in _ensure_client() so the operator stays |
| 19 | +# serialisable for Ray workers without requiring it at import time. |
| 20 | + |
| 21 | + |
| 22 | +# --------------------------------------------------------------------------- |
| 23 | +# Built-in strategy prompts |
| 24 | +# --------------------------------------------------------------------------- |
| 25 | + |
| 26 | +_PROMPTS: dict[str, str] = { |
| 27 | + "decompose": """\ |
| 28 | +You are a query decomposition assistant for a retrieval system. |
| 29 | +
|
| 30 | +Given a search query, break it down into up to {max_subqueries} distinct sub-queries that \ |
| 31 | +together cover all aspects and angles of the original query. Generate as many sub-queries as \ |
| 32 | +are genuinely useful — do not pad with redundant ones just to hit the maximum. Each sub-query \ |
| 33 | +should target a specific facet, making it easier for a dense retrieval system to find all \ |
| 34 | +relevant documents. |
| 35 | +
|
| 36 | +Rules: |
| 37 | +- Each sub-query must be self-contained and meaningful on its own. |
| 38 | +- Sub-queries should be diverse and complementary, not redundant. |
| 39 | +- Use clear, precise language suited for dense embedding retrieval. |
| 40 | +- Output a JSON array of strings only — no explanation, no markdown fences.""", |
| 41 | + |
| 42 | + "hyde": """\ |
| 43 | +You are a Hypothetical Document Embedding (HyDE) assistant for a retrieval system. |
| 44 | +
|
| 45 | +Given a search query, generate up to {max_subqueries} short hypothetical document passages \ |
| 46 | +(2–4 sentences each) that would directly answer or address the query. Generate as many as are \ |
| 47 | +genuinely useful — fewer is fine if the query is simple. These passages will be used as queries \ |
| 48 | +to a dense retrieval system to find real, similar documents. |
| 49 | +
|
| 50 | +Rules: |
| 51 | +- Each passage should read like a real document excerpt that answers the query. |
| 52 | +- Vary the style and perspective across passages (e.g., academic, technical, narrative). |
| 53 | +- Be factually plausible; focus on covering the query intent. |
| 54 | +- Output a JSON array of strings only — no explanation, no markdown fences.""", |
| 55 | + |
| 56 | + "multi_perspective": """\ |
| 57 | +You are a multi-perspective query expansion assistant for a retrieval system. |
| 58 | +
|
| 59 | +Given a search query, generate up to {max_subqueries} reformulations from different angles, \ |
| 60 | +perspectives, or levels of specificity to maximise recall in a dense retrieval system. Only \ |
| 61 | +generate reformulations that add genuine coverage — do not pad. |
| 62 | +
|
| 63 | +Rules: |
| 64 | +- Vary terminology: use synonyms, technical vs. casual language, acronyms vs. full names. |
| 65 | +- Vary scope: broad overview queries alongside narrow, specific ones. |
| 66 | +- Vary form: declarative statements, questions, and entity-focused queries. |
| 67 | +- Each reformulation must have a meaningfully different surface form from the others. |
| 68 | +- Output a JSON array of strings only — no explanation, no markdown fences.""", |
| 69 | +} |
| 70 | + |
| 71 | + |
| 72 | +# --------------------------------------------------------------------------- |
| 73 | +# Operator |
| 74 | +# --------------------------------------------------------------------------- |
| 75 | + |
| 76 | + |
class SubQueryGeneratorOperator(AbstractOperator, CPUOperator):
    """Expand each query row into sub-query rows using an LLM.

    The operator calls an LLM (via the OpenAI SDK) once per input query and
    explodes the result into one output row per generated sub-query.  The LLM
    decides how many sub-queries to generate; anything beyond
    ``max_subqueries`` is discarded.  This makes it a natural upstream stage
    for a retrieval operator: the downstream operator can retrieve documents
    independently for every sub-query row, and a subsequent aggregation step
    (e.g. RRF) can merge the per-sub-query ranked lists back into a single
    ranking per ``query_id``.

    Input DataFrame schema
    ----------------------
    query_id    : str  — unique identifier for the query
    query_text  : str  — the search query text
    (any additional columns are passed through unchanged)

    Output DataFrame schema
    -----------------------
    query_id      : str  — same ``query_id`` as the input row
    query_text    : str  — original query text (preserved for context)
    subquery_idx  : int  — 0-based position within the generated sub-query group
    subquery_text : str  — the generated sub-query text
    (additional input columns are forwarded to every expanded row; input
    columns named ``subquery_idx`` / ``subquery_text`` are replaced by the
    generated values)

    Parameters
    ----------
    llm_model : str
        OpenAI model identifier, e.g. ``"gpt-4o"``.
    max_subqueries : int
        Maximum number of sub-queries kept per query.  Must be >= 1.
        The LLM may generate fewer if the query does not warrant the maximum.
        Defaults to ``4``.
    strategy : {"decompose", "hyde", "multi_perspective"}
        Built-in sub-query generation strategy.

        ``"decompose"``
            Break the query into complementary sub-aspects (default).
        ``"hyde"``
            Generate hypothetical answer passages (HyDE).
        ``"multi_perspective"``
            Rewrite the query from diverse angles to maximise recall.
    api_key : str, optional
        Literal API key **or** an ``"os.environ/VAR_NAME"`` reference that is
        resolved when the client is first created (on the first ``process``
        call).  Omit to rely on the ``OPENAI_API_KEY`` environment variable.
    base_url : str, optional
        Custom endpoint URL — useful for NIM deployments or self-hosted models.
    max_tokens : int, optional
        Upper bound on tokens in the LLM response.
    system_prompt_override : str, optional
        Fully custom system prompt.  Use ``{max_subqueries}`` as a placeholder.
        The prompt is rendered with ``str.format``, so literal braces must be
        doubled (``{{`` / ``}}``).  When provided, ``strategy`` is ignored.

    Raises
    ------
    ValueError
        If ``max_subqueries`` is smaller than 1, or if ``strategy`` is not a
        built-in strategy and no ``system_prompt_override`` is given.

    Examples
    --------
    Standalone use::

        import pandas as pd
        from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator

        op = SubQueryGeneratorOperator(llm_model="gpt-4o", max_subqueries=5)
        df = pd.DataFrame({
            "query_id": ["q1", "q2"],
            "query_text": ["What causes inflation?", "How do vaccines work?"],
        })
        result = op.run(df)
        # result has up to 10 rows: ≤5 sub-queries × 2 original queries

    Composing into a graph::

        from nemo_retriever.graph import InprocessExecutor
        from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator

        graph = (
            SubQueryGeneratorOperator(llm_model="gpt-4o", max_subqueries=4)
            >> RetrievalOperator(retriever=my_retriever)
            >> RRFAggregatorOperator()
        )
        executor = InprocessExecutor(graph)
        result_df = executor.ingest(query_df)
    """

    # Columns written by ``process``.  Input columns with these names are never
    # treated as pass-through, so they cannot overwrite the generated values.
    _OUTPUT_COLUMNS = ("query_id", "query_text", "subquery_idx", "subquery_text")

    def __init__(
        self,
        *,
        llm_model: str,
        max_subqueries: int = 4,
        strategy: Literal["decompose", "hyde", "multi_perspective"] = "decompose",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        max_tokens: Optional[int] = None,
        system_prompt_override: Optional[str] = None,
    ) -> None:
        super().__init__()

        # Fail fast on configuration errors instead of surfacing them as a
        # KeyError (or a nonsensical prompt) in the middle of a pipeline run.
        if max_subqueries < 1:
            raise ValueError(f"max_subqueries must be >= 1, got {max_subqueries!r}.")
        if not system_prompt_override and strategy not in _PROMPTS:
            raise ValueError(
                f"Unknown strategy {strategy!r}. " f"Expected one of: {sorted(_PROMPTS)}."
            )

        self._llm_model = llm_model
        self._max_subqueries = max_subqueries
        self._strategy = strategy
        self._api_key = api_key
        self._base_url = base_url
        self._max_tokens = max_tokens
        self._system_prompt_override = system_prompt_override

        # OpenAI client is created lazily so the operator stays serialisable for Ray.
        self._client: Any = None

    # ------------------------------------------------------------------
    # AbstractOperator interface
    # ------------------------------------------------------------------

    def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame:
        """Normalise *data* to a DataFrame with ``query_id`` / ``query_text`` columns.

        Accepted input types
        --------------------
        ``pd.DataFrame``
            Must contain at least ``query_id`` and ``query_text`` columns.
            A copy is returned; the caller's frame is not modified.
        ``list[str]``
            Plain query strings; ``query_id`` values are auto-assigned as
            ``"q0"``, ``"q1"``, …
        ``list[tuple[str, str]]`` or ``list[list[str, str]]``
            ``(query_id, query_text)`` pairs.
        empty ``list``
            Yields an empty DataFrame with the two required columns.

        Raises
        ------
        ValueError
            If a DataFrame is missing ``query_id`` or ``query_text``.
        TypeError
            If *data* is none of the accepted types.
        """
        if isinstance(data, pd.DataFrame):
            missing = {"query_id", "query_text"} - set(data.columns)
            if missing:
                raise ValueError(
                    f"Input DataFrame is missing required column(s): {sorted(missing)}. "
                    "Expected at minimum: 'query_id' and 'query_text'."
                )
            return data.copy()

        if isinstance(data, list):
            if not data:
                # Nothing to expand: an empty batch is valid input, not an error.
                return pd.DataFrame(columns=["query_id", "query_text"])

            # The element type is inferred from the first item only.
            first = data[0]
            if isinstance(first, str):
                return pd.DataFrame(
                    {
                        "query_id": [f"q{i}" for i in range(len(data))],
                        "query_text": list(data),
                    }
                )
            if isinstance(first, (tuple, list)) and len(first) == 2:
                return pd.DataFrame(data, columns=["query_id", "query_text"])

        raise TypeError(
            f"Unsupported input type {type(data).__name__!r}. "
            "Pass a pd.DataFrame with 'query_id' and 'query_text' columns, "
            "a list[str], or a list[tuple[str, str]]."
        )

    def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
        """Generate sub-queries for every row and explode to one row per sub-query.

        The returned frame always has the four output columns followed by the
        pass-through columns, including when no rows are produced.
        """
        self._ensure_client()
        system_prompt = self._build_system_prompt()

        # Exclude every column this operator writes, so an input column called
        # e.g. "subquery_text" cannot overwrite the generated value.
        passthrough_cols = [c for c in data.columns if c not in self._OUTPUT_COLUMNS]
        output_cols = [*self._OUTPUT_COLUMNS, *passthrough_cols]
        rows: List[dict[str, Any]] = []

        for _, row in data.iterrows():
            subqueries = self._generate_one(row["query_text"], system_prompt)
            for idx, sq in enumerate(subqueries):
                new_row: dict[str, Any] = {
                    "query_id": row["query_id"],
                    "query_text": row["query_text"],
                    "subquery_idx": idx,
                    "subquery_text": sq,
                }
                for col in passthrough_cols:
                    new_row[col] = row[col]
                rows.append(new_row)

        # Passing ``columns`` keeps the schema identical for empty and
        # non-empty results.
        return pd.DataFrame(rows, columns=output_cols)

    def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
        """Return *data* unchanged."""
        return data

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_client(self) -> None:
        """Lazily create the OpenAI client (once per instance).

        Raises
        ------
        ImportError
            If the ``openai`` package is not installed.
        KeyError
            If ``api_key`` is an ``"os.environ/VAR"`` reference and ``VAR`` is
            not set.
        """
        if self._client is not None:
            return
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ImportError(
                "SubQueryGeneratorOperator requires 'openai'. "
                "Install it with: pip install 'openai>=1.0'"
            ) from exc

        api_key = self._api_key
        if api_key is not None and api_key.strip().startswith("os.environ/"):
            var = api_key.strip().removeprefix("os.environ/")
            try:
                api_key = os.environ[var]
            except KeyError as exc:
                # Same exception type as before, but say which variable is missing.
                raise KeyError(
                    f"Environment variable {var!r} referenced by api_key is not set."
                ) from exc

        self._client = OpenAI(
            api_key=api_key,
            **({"base_url": self._base_url} if self._base_url is not None else {}),
        )

    def _build_system_prompt(self) -> str:
        """Render the system prompt with ``max_subqueries`` filled in.

        A non-empty ``system_prompt_override`` takes precedence over the
        built-in template selected by ``strategy``.
        """
        template = self._system_prompt_override or _PROMPTS[self._strategy]
        return template.format(max_subqueries=self._max_subqueries)

    def _generate_one(self, query: str, system_prompt: str) -> List[str]:
        """Call the LLM and return a list of sub-query strings for *query*.

        The result is capped at ``max_subqueries`` items.  When the response
        carries no usable text, parsing falls back to ``[query]``.
        """
        call_kwargs: dict[str, Any] = dict(
            model=self._llm_model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Query: {query}"},
            ],
        )
        if self._max_tokens is not None:
            call_kwargs["max_tokens"] = self._max_tokens

        response = self._client.chat.completions.create(**call_kwargs)

        # ``message.content`` can be None and ``choices`` can be empty; treat
        # both as an empty reply so the parser falls back to the query itself.
        choices = getattr(response, "choices", None) or []
        content = choices[0].message.content if choices else None
        raw = (content or "").strip()

        # The prompt only *asks* for at most N items; enforce the cap here.
        return _parse_json_list(raw, fallback=query)[: self._max_subqueries]
| 302 | + |
| 303 | +# --------------------------------------------------------------------------- |
| 304 | +# Module-level helpers (no instance state — easier to test in isolation) |
| 305 | +# --------------------------------------------------------------------------- |
| 306 | + |
| 307 | + |
| 308 | +def _parse_json_list(raw: str, *, fallback: str) -> List[str]: |
| 309 | + """Parse a JSON array from *raw*, stripping markdown fences if present. |
| 310 | +
|
| 311 | + Returns *[fallback]* when parsing fails so downstream stages always |
| 312 | + receive at least one sub-query row. |
| 313 | + """ |
| 314 | + text = raw |
| 315 | + for fence in ("```json", "```"): |
| 316 | + if text.startswith(fence): |
| 317 | + text = text[len(fence):] |
| 318 | + break |
| 319 | + if text.endswith("```"): |
| 320 | + text = text[:-3] |
| 321 | + text = text.strip() |
| 322 | + |
| 323 | + try: |
| 324 | + parsed = json.loads(text) |
| 325 | + if isinstance(parsed, list) and parsed and all(isinstance(s, str) for s in parsed): |
| 326 | + return parsed |
| 327 | + except json.JSONDecodeError: |
| 328 | + pass |
| 329 | + |
| 330 | + return [fallback] |
0 commit comments