From 24860718724a6718f0792dace5e2ad13033ed55b Mon Sep 17 00:00:00 2001 From: ha2trinh Date: Tue, 22 Apr 2025 15:54:51 -0700 Subject: [PATCH 1/7] fixed token count for drift search --- graphrag/query/structured_search/drift_search/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py index 2099e3592e..10d6234327 100644 --- a/graphrag/query/structured_search/drift_search/search.py +++ b/graphrag/query/structured_search/drift_search/search.py @@ -213,7 +213,7 @@ async def search( primer_context, token_ct = await self.context_builder.build_context(query) llm_calls["build_context"] = token_ct["llm_calls"] prompt_tokens["build_context"] = token_ct["prompt_tokens"] - output_tokens["build_context"] = token_ct["prompt_tokens"] + output_tokens["build_context"] = token_ct["output_tokens"] primer_response = await self.primer.search( query=query, top_k_reports=primer_context From 907f637956188da9c4fb17f162f83eb7ebe37c63 Mon Sep 17 00:00:00 2001 From: ha2trinh Date: Tue, 22 Apr 2025 16:10:31 -0700 Subject: [PATCH 2/7] basic search fixes --- graphrag/config/defaults.py | 1 + graphrag/config/models/basic_search_config.py | 4 ++ graphrag/query/factory.py | 3 + .../basic_search/basic_context.py | 65 ++++++++++++++----- .../structured_search/basic_search/search.py | 7 ++ 5 files changed, 63 insertions(+), 17 deletions(-) diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index cea7ba8ea3..627544ec79 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -42,6 +42,7 @@ class BasicSearchDefaults: prompt: None = None k: int = 10 + max_context_tokens: int = 12_000 chat_model_id: str = DEFAULT_CHAT_MODEL_ID embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID diff --git a/graphrag/config/models/basic_search_config.py b/graphrag/config/models/basic_search_config.py index 8221cd3ff5..66a1e68577 100644 --- a/graphrag/config/models/basic_search_config.py +++ b/graphrag/config/models/basic_search_config.py @@ -27,3 +27,7 @@ class BasicSearchConfig(BaseModel): description="The number of text units to include in search context.", default=graphrag_config_defaults.basic_search.k, ) + max_context_tokens: int = Field( + description="The maximum tokens.", + default=graphrag_config_defaults.basic_search.max_context_tokens, + ) diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py index 76fa1f430f..907c83cacf 100644 --- a/graphrag/query/factory.py +++ b/graphrag/query/factory.py @@ -275,6 +275,7 @@ def get_basic_search_engine( text_unit_embeddings: BaseVectorStore, config: GraphRagConfig, system_prompt: str | None = None, + response_type: str = "multiple paragraphs", callbacks: list[QueryCallbacks] | None = None, ) -> BasicSearch: """Create a basic search engine based on data + configuration.""" @@ -312,6 +313,7 @@ def get_basic_search_engine( return BasicSearch( model=chat_model, system_prompt=system_prompt, + response_type=response_type, context_builder=BasicSearchContext( text_embedder=embedding_model, text_unit_embeddings=text_unit_embeddings, @@ -323,6 +325,7 @@ def get_basic_search_engine( context_builder_params={ "embedding_vectorstore_key": "id", "k": bs_config.k, + "max_context_tokens": bs_config.max_context_tokens, }, callbacks=callbacks, ) diff --git a/graphrag/query/structured_search/basic_search/basic_context.py b/graphrag/query/structured_search/basic_search/basic_context.py index 0dd139ea86..707f63568d 100644 --- a/graphrag/query/structured_search/basic_search/basic_context.py +++ b/graphrag/query/structured_search/basic_search/basic_context.py @@ -3,6 +3,7 @@ """Basic Context Builder implementation.""" +import logging import pandas as pd import tiktoken @@ -14,7 +15,9 @@ ) from graphrag.query.context_builder.conversation_history import ConversationHistory from graphrag.vector_stores.base import BaseVectorStore +from graphrag.query.llm.text_utils import num_tokens +log = logging.getLogger(__name__) class BasicSearchContext(BasicContextBuilder): """Class representing the Basic Search Context Builder.""" @@ -32,30 +35,58 @@ def __init__( self.text_units = text_units self.text_unit_embeddings = text_unit_embeddings self.embedding_vectorstore_key = embedding_vectorstore_key + self.text_id_map = self._map_ids() def build_context( self, query: str, conversation_history: ConversationHistory | None = None, + k: int = 10, + max_context_tokens: int = 12_000, + context_name: str = "Sources", + column_delimiter: str = "|", + text_id_col: str = "source_id", + text_col: str = "text", **kwargs, ) -> ContextBuilderResult: - """Build the context for the local search mode.""" - search_results = self.text_unit_embeddings.similarity_search_by_text( - text=query, - text_embedder=lambda t: self.text_embedder.embed(t), - k=kwargs.get("k", 10), - ) - # we don't have a friendly id on text_units, so just copy the index - sources = [ - {"id": str(search_results.index(r)), "text": r.document.text} - for r in search_results - ] - # make a delimited table for the context; this imitates graphrag context building - table = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources] - - columns = pd.Index(["id", "text"]) + """Build the context for the basic search mode.""" + if query != "": + related_texts = self.text_unit_embeddings.similarity_search_by_text( + text=query, + text_embedder=lambda t: self.text_embedder.embed(t), + k=k, + ) + related_text_list = [{text_id_col: self.text_id_map[chunk.document.id], text_col: chunk.document.text} for chunk in related_texts] + related_text_df = pd.DataFrame(related_text_list) + else: + related_text_df = pd.DataFrame(columns=[text_id_col, text_col], data=[]) + + # add these related text chunks into context until we fill up the context window + current_tokens = 0 + text_ids = [] + current_tokens = num_tokens(text_id_col + column_delimiter + text_col + "\n", self.token_encoder) + for i, row in related_text_df.iterrows(): + text = row[text_id_col] + column_delimiter + row[text_col] + "\n" + tokens = num_tokens(text, self.token_encoder) + if current_tokens + tokens > max_context_tokens: + log.info(f"Reached token limit: {current_tokens+tokens}. Reverting to previous context state") + break + + current_tokens += tokens + text_ids.append(i) + final_text_df = related_text_df[related_text_df.index.isin(text_ids)].reset_index(drop=True) + final_text = final_text_df.to_csv(index=False, escapechar='\\', sep=column_delimiter) return ContextBuilderResult( - context_chunks="\n\n".join(table), - context_records={"sources": pd.DataFrame(sources, columns=columns)}, + context_chunks=final_text, + context_records={context_name: final_text_df}, ) + + def _map_ids(self) -> dict[str, str]: + """Map id to short id in the text units""" + id_map = {} + for unit in self.text_units: + id_map[unit.id] = unit.short_id + return id_map + + \ No newline at end of file diff --git a/graphrag/query/structured_search/basic_search/search.py b/graphrag/query/structured_search/basic_search/search.py index e2fb29c012..65635df01e 100644 --- a/graphrag/query/structured_search/basic_search/search.py +++ b/graphrag/query/structured_search/basic_search/search.py @@ -108,6 +108,10 @@ async def search( llm_calls=1, prompt_tokens=num_tokens(search_prompt, self.token_encoder), output_tokens=sum(output_tokens.values()), + llm_calls_categories=llm_calls, + prompt_tokens_categories=prompt_tokens, + output_tokens_categories=output_tokens, + ) except Exception: @@ -120,6 +124,9 @@ async def search( llm_calls=1, prompt_tokens=num_tokens(search_prompt, self.token_encoder), output_tokens=0, + llm_calls_categories=llm_calls, + prompt_tokens_categories=prompt_tokens, + output_tokens_categories=output_tokens, ) async def stream_search( From 487ecc01e957d5cc94aa9839d2fae7d9a561923f Mon Sep 17 00:00:00 2001 From: ha2trinh Date: Tue, 22 Apr 2025 16:14:32 -0700 Subject: [PATCH 3/7] updated basic search prompt --- .../query/basic_search_system_prompt.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/graphrag/prompts/query/basic_search_system_prompt.py b/graphrag/prompts/query/basic_search_system_prompt.py index f98ea0582c..a20fb6ad10 100644 --- a/graphrag/prompts/query/basic_search_system_prompt.py +++ b/graphrag/prompts/query/basic_search_system_prompt.py @@ -11,23 +11,25 @@ ---Goal--- -Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. +Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data tables appropriate for the response length and format. -If you don't know the answer, just say so. Do not make anything up. +You should use the data provided in the data tables below as the primary context for generating the response. + +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. Points supported by data should list their data references as follows: -"This is an example sentence supported by multiple text references [Data: Sources (record ids)]." +"This is an example sentence supported by multiple data references [Data: Sources (record ids)]." Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. For example: -"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]." +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]" -where 15 and 16 represent the id (not the index) of the relevant data record. +where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables. -Do not include information where the supporting text for it is not provided. +Do not include information where the supporting evidence for it is not provided. ---Target response length and format--- @@ -42,23 +44,26 @@ ---Goal--- -Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. +Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data appropriate for the response length and format. + +You should use the data provided in the data tables below as the primary context for generating the response. -If you don't know the answer, just say so. Do not make anything up. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. Points supported by data should list their data references as follows: -"This is an example sentence supported by multiple text references [Data: Sources (record ids)]." +"This is an example sentence supported by multiple data references [Data: Sources (record ids)]." Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. For example: -"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]." +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables. -where 15 and 16 represent the id (not the index) of the relevant data record. +Do not include information where the supporting evidence for it is not provided. -Do not include information where the supporting text for it is not provided. ---Target response length and format--- From c72a241af9687463faceb00ad5e3a4aa89ef285d Mon Sep 17 00:00:00 2001 From: ha2trinh Date: Tue, 22 Apr 2025 16:28:04 -0700 Subject: [PATCH 4/7] fixed text splitting logic --- graphrag/index/text_splitting/text_splitting.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/graphrag/index/text_splitting/text_splitting.py b/graphrag/index/text_splitting/text_splitting.py index 1632904637..57f2f23659 100644 --- a/graphrag/index/text_splitting/text_splitting.py +++ b/graphrag/index/text_splitting/text_splitting.py @@ -152,6 +152,8 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]: while start_idx < len(input_ids): chunk_text = tokenizer.decode(list(chunk_ids)) result.append(chunk_text) # Append chunked text as string + if cur_idx == len(input_ids): + break start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] @@ -186,6 +188,8 @@ def split_multiple_texts_on_tokens( chunk_text = tokenizer.decode([id for _, id in chunk_ids]) doc_indices = list({doc_idx for doc_idx, _ in chunk_ids}) result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids))) + if cur_idx == len(input_ids): + break start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] From 49a694fef11747ba59e8acea165335829f3ff330 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Wed, 23 Apr 2025 16:48:09 -0700 Subject: [PATCH 5/7] Lint/format --- .../basic_search/basic_context.py | 49 +++++++++++++------ .../structured_search/basic_search/search.py | 1 - 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/graphrag/query/structured_search/basic_search/basic_context.py b/graphrag/query/structured_search/basic_search/basic_context.py index 707f63568d..d08bdc2e73 100644 --- a/graphrag/query/structured_search/basic_search/basic_context.py +++ b/graphrag/query/structured_search/basic_search/basic_context.py @@ -4,6 +4,8 @@ """Basic Context Builder implementation.""" import logging +from typing import cast + import pandas as pd import tiktoken @@ -14,11 +16,12 @@ ContextBuilderResult, ) from graphrag.query.context_builder.conversation_history import ConversationHistory -from graphrag.vector_stores.base import BaseVectorStore from graphrag.query.llm.text_utils import num_tokens +from graphrag.vector_stores.base import BaseVectorStore log = logging.getLogger(__name__) + class BasicSearchContext(BasicContextBuilder): """Class representing the Basic Search Context Builder.""" @@ -56,37 +59,55 @@ def build_context( text_embedder=lambda t: self.text_embedder.embed(t), k=k, ) - related_text_list = [{text_id_col: self.text_id_map[chunk.document.id], text_col: chunk.document.text} for chunk in related_texts] + related_text_list = [ + { + text_id_col: self.text_id_map[f"{chunk.document.id}"], + text_col: chunk.document.text, + } + for chunk in related_texts + ] related_text_df = pd.DataFrame(related_text_list) else: - related_text_df = pd.DataFrame(columns=[text_id_col, text_col], data=[]) - + related_text_df = pd.DataFrame({ + text_id_col: [], + text_col: [], + }) + # add these related text chunks into context until we fill up the context window current_tokens = 0 text_ids = [] - current_tokens = num_tokens(text_id_col + column_delimiter + text_col + "\n", self.token_encoder) + current_tokens = num_tokens( + text_id_col + column_delimiter + text_col + "\n", self.token_encoder + ) for i, row in related_text_df.iterrows(): text = row[text_id_col] + column_delimiter + row[text_col] + "\n" tokens = num_tokens(text, self.token_encoder) if current_tokens + tokens > max_context_tokens: - log.info(f"Reached token limit: {current_tokens+tokens}. Reverting to previous context state") + msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state" + log.info(msg) break - + current_tokens += tokens text_ids.append(i) - final_text_df = related_text_df[related_text_df.index.isin(text_ids)].reset_index(drop=True) - final_text = final_text_df.to_csv(index=False, escapechar='\\', sep=column_delimiter) + final_text_df = cast( + "pd.DataFrame", + related_text_df[related_text_df.index.isin(text_ids)].reset_index( + drop=True + ), + ) + final_text = final_text_df.to_csv( + index=False, escapechar="\\", sep=column_delimiter + ) return ContextBuilderResult( context_chunks=final_text, context_records={context_name: final_text_df}, ) - + def _map_ids(self) -> dict[str, str]: - """Map id to short id in the text units""" + """Map id to short id in the text units.""" id_map = {} - for unit in self.text_units: + text_units = self.text_units or [] + for unit in text_units: id_map[unit.id] = unit.short_id return id_map - - \ No newline at end of file diff --git a/graphrag/query/structured_search/basic_search/search.py b/graphrag/query/structured_search/basic_search/search.py index 65635df01e..a5dca578d3 100644 --- a/graphrag/query/structured_search/basic_search/search.py +++ b/graphrag/query/structured_search/basic_search/search.py @@ -111,7 +111,6 @@ async def search( llm_calls_categories=llm_calls, prompt_tokens_categories=prompt_tokens, output_tokens_categories=output_tokens, - ) except Exception: From 1ad1e2ed95a311a9f13f98b054c0b9099670d394 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Wed, 23 Apr 2025 16:48:43 -0700 Subject: [PATCH 6/7] Semver --- .semversioner/next-release/patch-20250423234829757628.json | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .semversioner/next-release/patch-20250423234829757628.json diff --git a/.semversioner/next-release/patch-20250423234829757628.json b/.semversioner/next-release/patch-20250423234829757628.json new file mode 100644 index 0000000000..7af4894f9e --- /dev/null +++ b/.semversioner/next-release/patch-20250423234829757628.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Fixes to basic search." +} From af8e6ca90d882b87906c0d0c86edc211b13dd81f Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Wed, 23 Apr 2025 16:55:47 -0700 Subject: [PATCH 7/7] Fix text splitting tests --- tests/unit/indexing/text_splitting/test_text_splitting.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/indexing/text_splitting/test_text_splitting.py b/tests/unit/indexing/text_splitting/test_text_splitting.py index da87d47350..10a5a06344 100644 --- a/tests/unit/indexing/text_splitting/test_text_splitting.py +++ b/tests/unit/indexing/text_splitting/test_text_splitting.py @@ -136,7 +136,6 @@ def test_split_single_text_on_tokens(): " by this t", "his test o", "est only.", - "nly.", ] result = split_single_text_on_tokens(text=text, tokenizer=tokenizer) @@ -197,7 +196,6 @@ def decode(tokens: list[int]) -> str: " this test", " test only", " only.", - ".", ] result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)