Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250423234829757628.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fixes to basic search."
}
1 change: 1 addition & 0 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class BasicSearchDefaults:

prompt: None = None
k: int = 10
max_context_tokens: int = 12_000
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID

Expand Down
4 changes: 4 additions & 0 deletions graphrag/config/models/basic_search_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ class BasicSearchConfig(BaseModel):
description="The number of text units to include in search context.",
default=graphrag_config_defaults.basic_search.k,
)
max_context_tokens: int = Field(
description="The maximum tokens.",
default=graphrag_config_defaults.basic_search.max_context_tokens,
)
4 changes: 4 additions & 0 deletions graphrag/index/text_splitting/text_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
while start_idx < len(input_ids):
chunk_text = tokenizer.decode(list(chunk_ids))
result.append(chunk_text) # Append chunked text as string
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
Expand Down Expand Up @@ -186,6 +188,8 @@ def split_multiple_texts_on_tokens(
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
Expand Down
29 changes: 17 additions & 12 deletions graphrag/prompts/query/basic_search_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,25 @@

---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data tables appropriate for the response length and format.

If you don't know the answer, just say so. Do not make anything up.
You should use the data provided in the data tables below as the primary context for generating the response.

If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"

where 15 and 16 represent the id (not the index) of the relevant data record.
where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.

Do not include information where the supporting text for it is not provided.
Do not include information where the supporting evidence for it is not provided.


---Target response length and format---
Expand All @@ -42,23 +44,26 @@

---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data appropriate for the response length and format.

You should use the data provided in the data tables below as the primary context for generating the response.

If you don't know the answer, just say so. Do not make anything up.
If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"

where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.

where 15 and 16 represent the id (not the index) of the relevant data record.
Do not include information where the supporting evidence for it is not provided.

Do not include information where the supporting text for it is not provided.

---Target response length and format---

Expand Down
3 changes: 3 additions & 0 deletions graphrag/query/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def get_basic_search_engine(
text_unit_embeddings: BaseVectorStore,
config: GraphRagConfig,
system_prompt: str | None = None,
response_type: str = "multiple paragraphs",
callbacks: list[QueryCallbacks] | None = None,
) -> BasicSearch:
"""Create a basic search engine based on data + configuration."""
Expand Down Expand Up @@ -312,6 +313,7 @@ def get_basic_search_engine(
return BasicSearch(
model=chat_model,
system_prompt=system_prompt,
response_type=response_type,
context_builder=BasicSearchContext(
text_embedder=embedding_model,
text_unit_embeddings=text_unit_embeddings,
Expand All @@ -323,6 +325,7 @@ def get_basic_search_engine(
context_builder_params={
"embedding_vectorstore_key": "id",
"k": bs_config.k,
"max_context_tokens": bs_config.max_context_tokens,
},
callbacks=callbacks,
)
82 changes: 67 additions & 15 deletions graphrag/query/structured_search/basic_search/basic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

"""Basic Context Builder implementation."""

import logging
from typing import cast

import pandas as pd
import tiktoken

Expand All @@ -13,8 +16,11 @@
ContextBuilderResult,
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.llm.text_utils import num_tokens
from graphrag.vector_stores.base import BaseVectorStore

log = logging.getLogger(__name__)


class BasicSearchContext(BasicContextBuilder):
"""Class representing the Basic Search Context Builder."""
Expand All @@ -32,30 +38,76 @@ def __init__(
self.text_units = text_units
self.text_unit_embeddings = text_unit_embeddings
self.embedding_vectorstore_key = embedding_vectorstore_key
self.text_id_map = self._map_ids()

def build_context(
self,
query: str,
conversation_history: ConversationHistory | None = None,
k: int = 10,
max_context_tokens: int = 12_000,
context_name: str = "Sources",
column_delimiter: str = "|",
text_id_col: str = "source_id",
text_col: str = "text",
**kwargs,
) -> ContextBuilderResult:
"""Build the context for the local search mode."""
search_results = self.text_unit_embeddings.similarity_search_by_text(
text=query,
text_embedder=lambda t: self.text_embedder.embed(t),
k=kwargs.get("k", 10),
"""Build the context for the basic search mode."""
if query != "":
related_texts = self.text_unit_embeddings.similarity_search_by_text(
text=query,
text_embedder=lambda t: self.text_embedder.embed(t),
k=k,
)
related_text_list = [
{
text_id_col: self.text_id_map[f"{chunk.document.id}"],
text_col: chunk.document.text,
}
for chunk in related_texts
]
related_text_df = pd.DataFrame(related_text_list)
else:
related_text_df = pd.DataFrame({
text_id_col: [],
text_col: [],
})

# add these related text chunks into context until we fill up the context window
current_tokens = 0
text_ids = []
current_tokens = num_tokens(
text_id_col + column_delimiter + text_col + "\n", self.token_encoder
)
# we don't have a friendly id on text_units, so just copy the index
sources = [
{"id": str(search_results.index(r)), "text": r.document.text}
for r in search_results
]
# make a delimited table for the context; this imitates graphrag context building
table = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources]
for i, row in related_text_df.iterrows():
text = row[text_id_col] + column_delimiter + row[text_col] + "\n"
tokens = num_tokens(text, self.token_encoder)
if current_tokens + tokens > max_context_tokens:
msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state"
log.info(msg)
break

columns = pd.Index(["id", "text"])
current_tokens += tokens
text_ids.append(i)
final_text_df = cast(
"pd.DataFrame",
related_text_df[related_text_df.index.isin(text_ids)].reset_index(
drop=True
),
)
final_text = final_text_df.to_csv(
index=False, escapechar="\\", sep=column_delimiter
)

return ContextBuilderResult(
context_chunks="\n\n".join(table),
context_records={"sources": pd.DataFrame(sources, columns=columns)},
context_chunks=final_text,
context_records={context_name: final_text_df},
)

def _map_ids(self) -> dict[str, str]:
"""Map id to short id in the text units."""
id_map = {}
text_units = self.text_units or []
for unit in text_units:
id_map[unit.id] = unit.short_id
return id_map
6 changes: 6 additions & 0 deletions graphrag/query/structured_search/basic_search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ async def search(
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)

except Exception:
Expand All @@ -120,6 +123,9 @@ async def search(
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)

async def stream_search(
Expand Down
2 changes: 1 addition & 1 deletion graphrag/query/structured_search/drift_search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def search(
primer_context, token_ct = await self.context_builder.build_context(query)
llm_calls["build_context"] = token_ct["llm_calls"]
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["output_tokens"]

primer_response = await self.primer.search(
query=query, top_k_reports=primer_context
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/indexing/text_splitting/test_text_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def test_split_single_text_on_tokens():
" by this t",
"his test o",
"est only.",
"nly.",
]

result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
Expand Down Expand Up @@ -197,7 +196,6 @@ def decode(tokens: list[int]) -> str:
" this test",
" test only",
" only.",
".",
]

result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
Expand Down
Loading