diff --git a/.semversioner/next-release/patch-20250423234829757628.json b/.semversioner/next-release/patch-20250423234829757628.json
new file mode 100644
index 0000000000..7af4894f9e
--- /dev/null
+++ b/.semversioner/next-release/patch-20250423234829757628.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fixes to basic search."
+}
diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py
index cea7ba8ea3..627544ec79 100644
--- a/graphrag/config/defaults.py
+++ b/graphrag/config/defaults.py
@@ -42,6 +42,7 @@ class BasicSearchDefaults:
 
     prompt: None = None
     k: int = 10
+    max_context_tokens: int = 12_000
     chat_model_id: str = DEFAULT_CHAT_MODEL_ID
     embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
 
diff --git a/graphrag/config/models/basic_search_config.py b/graphrag/config/models/basic_search_config.py
index 8221cd3ff5..66a1e68577 100644
--- a/graphrag/config/models/basic_search_config.py
+++ b/graphrag/config/models/basic_search_config.py
@@ -27,3 +27,7 @@ class BasicSearchConfig(BaseModel):
         description="The number of text units to include in search context.",
         default=graphrag_config_defaults.basic_search.k,
     )
+    max_context_tokens: int = Field(
+        description="The maximum number of tokens to include in the search context.",
+        default=graphrag_config_defaults.basic_search.max_context_tokens,
+    )
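Illustration, not part of the patch: the new max_context_tokens knob flows from BasicSearchDefaults through BasicSearchConfig and, via the factory change below, into the context builder. A minimal sketch of the resulting model surface, assuming the defaults shown above (the zero-argument construction is an assumption based on every visible field having a default):

    from graphrag.config.models.basic_search_config import BasicSearchConfig

    # Field defaults come from graphrag_config_defaults.basic_search
    config = BasicSearchConfig()
    assert config.k == 10
    assert config.max_context_tokens == 12_000

    # A caller targeting a smaller context window might tighten the budget:
    config = BasicSearchConfig(k=5, max_context_tokens=4_000)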
diff --git a/graphrag/index/text_splitting/text_splitting.py b/graphrag/index/text_splitting/text_splitting.py
index 1632904637..57f2f23659 100644
--- a/graphrag/index/text_splitting/text_splitting.py
+++ b/graphrag/index/text_splitting/text_splitting.py
@@ -152,6 +152,8 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
     while start_idx < len(input_ids):
         chunk_text = tokenizer.decode(list(chunk_ids))
         result.append(chunk_text)  # Append chunked text as string
+        if cur_idx == len(input_ids):
+            break
         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
         cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
@@ -186,6 +188,8 @@ def split_multiple_texts_on_tokens(
         chunk_text = tokenizer.decode([id for _, id in chunk_ids])
         doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
         result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
+        if cur_idx == len(input_ids):
+            break
         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
         cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
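Illustration, not part of the patch: the guard added above stops the chunking loop once a chunk has consumed the last token; previously the loop advanced start_idx one more time and could emit a final fragment consisting only of overlap. A standalone character-level toy (one character per "token"; the real functions round-trip through the project's tokenizer):

    def split_chars(text: str, tokens_per_chunk: int, chunk_overlap: int) -> list[str]:
        result: list[str] = []
        start_idx = 0
        cur_idx = min(start_idx + tokens_per_chunk, len(text))
        chunk = text[start_idx:cur_idx]
        while start_idx < len(text):
            result.append(chunk)
            if cur_idx == len(text):
                # New guard: the chunk reaching the end of input was just
                # emitted; stepping again would yield an overlap-only fragment.
                break
            start_idx += tokens_per_chunk - chunk_overlap
            cur_idx = min(start_idx + tokens_per_chunk, len(text))
            chunk = text[start_idx:cur_idx]
        return result

    # 12 "tokens", chunks of 10 with an overlap of 6 (step of 4):
    print(split_chars("hello world.", 10, 6))  # ['hello worl', 'o world.']
    # Without the guard the loop would also emit 'rld.', the same kind of
    # trailing fragment ('nly.', '.') removed from the unit tests below.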
diff --git a/graphrag/prompts/query/basic_search_system_prompt.py b/graphrag/prompts/query/basic_search_system_prompt.py
index f98ea0582c..a20fb6ad10 100644
--- a/graphrag/prompts/query/basic_search_system_prompt.py
+++ b/graphrag/prompts/query/basic_search_system_prompt.py
@@ -11,23 +11,25 @@
 
 ---Goal---
 
-Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
+Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data tables appropriate for the response length and format.
 
-If you don't know the answer, just say so. Do not make anything up.
+You should use the data provided in the data tables below as the primary context for generating the response.
+
+If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
 
 Points supported by data should list their data references as follows:
 
-"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
+"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."
 
 Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
 
 For example:
 
-"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"
 
-where 15 and 16 represent the id (not the index) of the relevant data record.
+where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.
 
-Do not include information where the supporting text for it is not provided.
+Do not include information where the supporting evidence for it is not provided.
 
 ---Target response length and format---
 
@@ -42,23 +44,26 @@
 
 ---Goal---
 
-Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
+Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data appropriate for the response length and format.
+
+You should use the data provided in the data tables below as the primary context for generating the response.
 
-If you don't know the answer, just say so. Do not make anything up.
+If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
 
 Points supported by data should list their data references as follows:
 
-"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
+"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."
 
 Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
 
 For example:
 
-"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"
+
+where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.
 
-where 15 and 16 represent the id (not the index) of the relevant data record.
+Do not include information where the supporting evidence for it is not provided.
 
-Do not include information where the supporting text for it is not provided.
 
 ---Target response length and format---
 
diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py
index 76fa1f430f..907c83cacf 100644
--- a/graphrag/query/factory.py
+++ b/graphrag/query/factory.py
@@ -275,6 +275,7 @@ def get_basic_search_engine(
     text_unit_embeddings: BaseVectorStore,
     config: GraphRagConfig,
     system_prompt: str | None = None,
+    response_type: str = "multiple paragraphs",
     callbacks: list[QueryCallbacks] | None = None,
 ) -> BasicSearch:
     """Create a basic search engine based on data + configuration."""
@@ -312,6 +313,7 @@
     return BasicSearch(
         model=chat_model,
         system_prompt=system_prompt,
+        response_type=response_type,
         context_builder=BasicSearchContext(
             text_embedder=embedding_model,
             text_unit_embeddings=text_unit_embeddings,
@@ -323,6 +325,7 @@
         context_builder_params={
             "embedding_vectorstore_key": "id",
             "k": bs_config.k,
+            "max_context_tokens": bs_config.max_context_tokens,
         },
         callbacks=callbacks,
     )
diff --git a/graphrag/query/structured_search/basic_search/basic_context.py b/graphrag/query/structured_search/basic_search/basic_context.py
index 0dd139ea86..d08bdc2e73 100644
--- a/graphrag/query/structured_search/basic_search/basic_context.py
+++ b/graphrag/query/structured_search/basic_search/basic_context.py
@@ -3,6 +3,9 @@
 
 """Basic Context Builder implementation."""
 
+import logging
+from typing import cast
+
 import pandas as pd
 import tiktoken
 
@@ -13,8 +16,11 @@
     ContextBuilderResult,
 )
 from graphrag.query.context_builder.conversation_history import ConversationHistory
+from graphrag.query.llm.text_utils import num_tokens
 from graphrag.vector_stores.base import BaseVectorStore
 
+log = logging.getLogger(__name__)
+
 
 class BasicSearchContext(BasicContextBuilder):
     """Class representing the Basic Search Context Builder."""
@@ -32,30 +38,76 @@ def __init__(
         self.text_units = text_units
         self.text_unit_embeddings = text_unit_embeddings
         self.embedding_vectorstore_key = embedding_vectorstore_key
+        self.text_id_map = self._map_ids()
 
     def build_context(
         self,
         query: str,
         conversation_history: ConversationHistory | None = None,
+        k: int = 10,
+        max_context_tokens: int = 12_000,
+        context_name: str = "Sources",
+        column_delimiter: str = "|",
+        text_id_col: str = "source_id",
+        text_col: str = "text",
         **kwargs,
     ) -> ContextBuilderResult:
-        """Build the context for the local search mode."""
-        search_results = self.text_unit_embeddings.similarity_search_by_text(
-            text=query,
-            text_embedder=lambda t: self.text_embedder.embed(t),
-            k=kwargs.get("k", 10),
+        """Build the context for the basic search mode."""
+        if query != "":
+            related_texts = self.text_unit_embeddings.similarity_search_by_text(
+                text=query,
+                text_embedder=lambda t: self.text_embedder.embed(t),
+                k=k,
+            )
+            related_text_list = [
+                {
+                    text_id_col: self.text_id_map[f"{chunk.document.id}"],
+                    text_col: chunk.document.text,
+                }
+                for chunk in related_texts
+            ]
+            related_text_df = pd.DataFrame(related_text_list)
+        else:
+            related_text_df = pd.DataFrame({
+                text_id_col: [],
+                text_col: [],
+            })
+
+        # add these related text chunks into context until we fill up the context window
+        current_tokens = 0
+        text_ids = []
+        current_tokens = num_tokens(
+            text_id_col + column_delimiter + text_col + "\n", self.token_encoder
         )
-        # we don't have a friendly id on text_units, so just copy the index
-        sources = [
-            {"id": str(search_results.index(r)), "text": r.document.text}
-            for r in search_results
-        ]
-        # make a delimited table for the context; this imitates graphrag context building
-        table = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources]
+        for i, row in related_text_df.iterrows():
+            text = row[text_id_col] + column_delimiter + row[text_col] + "\n"
+            tokens = num_tokens(text, self.token_encoder)
+            if current_tokens + tokens > max_context_tokens:
+                msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state"
+                log.info(msg)
+                break
 
-        columns = pd.Index(["id", "text"])
+            current_tokens += tokens
+            text_ids.append(i)
+        final_text_df = cast(
+            "pd.DataFrame",
+            related_text_df[related_text_df.index.isin(text_ids)].reset_index(
+                drop=True
+            ),
+        )
+        final_text = final_text_df.to_csv(
+            index=False, escapechar="\\", sep=column_delimiter
+        )
 
         return ContextBuilderResult(
-            context_chunks="\n\n".join(table),
-            context_records={"sources": pd.DataFrame(sources, columns=columns)},
+            context_chunks=final_text,
+            context_records={context_name: final_text_df},
         )
+
+    def _map_ids(self) -> dict[str, str]:
+        """Map id to short id in the text units."""
+        id_map = {}
+        text_units = self.text_units or []
+        for unit in text_units:
+            id_map[unit.id] = unit.short_id
+        return id_map
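Illustration, not part of the patch: build_context now fills the context greedily, counting the delimited header row first and stopping at the first row that would push the total past max_context_tokens. A standalone approximation that uses tiktoken directly and a plain list of dicts, where the real builder uses the shared num_tokens helper and a pandas DataFrame:

    import tiktoken

    encoder = tiktoken.get_encoding("cl100k_base")

    def fill_context(rows: list[dict[str, str]], max_context_tokens: int) -> list[dict[str, str]]:
        # Start the count at the header row, as the real builder does.
        current_tokens = len(encoder.encode("source_id|text\n"))
        kept: list[dict[str, str]] = []
        for row in rows:  # rows arrive ordered by similarity, best first
            line = row["source_id"] + "|" + row["text"] + "\n"
            tokens = len(encoder.encode(line))
            if current_tokens + tokens > max_context_tokens:
                break  # budget exhausted; drop this and all lower-ranked rows
            current_tokens += tokens
            kept.append(row)
        return kept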
diff --git a/graphrag/query/structured_search/basic_search/search.py b/graphrag/query/structured_search/basic_search/search.py
index e2fb29c012..a5dca578d3 100644
--- a/graphrag/query/structured_search/basic_search/search.py
+++ b/graphrag/query/structured_search/basic_search/search.py
@@ -108,6 +108,9 @@ async def search(
                 llm_calls=1,
                 prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                 output_tokens=sum(output_tokens.values()),
+                llm_calls_categories=llm_calls,
+                prompt_tokens_categories=prompt_tokens,
+                output_tokens_categories=output_tokens,
             )
 
         except Exception:
@@ -120,6 +123,9 @@
                 llm_calls=1,
                 prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                 output_tokens=0,
+                llm_calls_categories=llm_calls,
+                prompt_tokens_categories=prompt_tokens,
+                output_tokens_categories=output_tokens,
             )
 
     async def stream_search(
diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py
index 2099e3592e..10d6234327 100644
--- a/graphrag/query/structured_search/drift_search/search.py
+++ b/graphrag/query/structured_search/drift_search/search.py
@@ -213,7 +213,7 @@ async def search(
         primer_context, token_ct = await self.context_builder.build_context(query)
         llm_calls["build_context"] = token_ct["llm_calls"]
         prompt_tokens["build_context"] = token_ct["prompt_tokens"]
-        output_tokens["build_context"] = token_ct["prompt_tokens"]
+        output_tokens["build_context"] = token_ct["output_tokens"]
 
         primer_response = await self.primer.search(
             query=query, top_k_reports=primer_context
diff --git a/tests/unit/indexing/text_splitting/test_text_splitting.py b/tests/unit/indexing/text_splitting/test_text_splitting.py
index da87d47350..10a5a06344 100644
--- a/tests/unit/indexing/text_splitting/test_text_splitting.py
+++ b/tests/unit/indexing/text_splitting/test_text_splitting.py
@@ -136,7 +136,6 @@ def test_split_single_text_on_tokens():
         " by this t",
         "his test o",
         "est only.",
-        "nly.",
     ]
 
     result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
@@ -197,7 +196,6 @@ def decode(tokens: list[int]) -> str:
         " this test",
         " test only",
         " only.",
-        ".",
     ]
 
     result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
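Illustration, not part of the patch: the *_categories fields added to SearchResult above preserve the per-stage accounting dicts alongside the existing totals, and the one-line drift_search fix means the "build_context" entry now records output tokens instead of repeating the prompt-token count. With made-up numbers:

    llm_calls = {"build_context": 1, "response": 1}
    prompt_tokens = {"build_context": 2_048, "response": 8_192}
    output_tokens = {"build_context": 128, "response": 512}  # no longer a copy of prompt_tokens

    # Totals are computed as before...
    total_output = sum(output_tokens.values())

    # ...while the new fields keep the per-stage breakdown, roughly:
    # SearchResult(..., output_tokens=total_output, output_tokens_categories=output_tokens)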