diff --git a/.vscode/launch.json b/.vscode/launch.json index 9f949cb12f..cd9d85d17c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -24,7 +24,7 @@ "--root", "${input:root_folder}", "--method", "${input:query_method}", - "--query", "${input:query}" + "${input:query}" ] }, { diff --git a/docs/examples_notebooks/api_overview.ipynb b/docs/examples_notebooks/api_overview.ipynb index 2a0c0f15de..abcd7832fc 100644 --- a/docs/examples_notebooks/api_overview.ipynb +++ b/docs/examples_notebooks/api_overview.ipynb @@ -28,10 +28,11 @@ "from pathlib import Path\n", "from pprint import pprint\n", "\n", - "import graphrag.api as api\n", "import pandas as pd\n", "from graphrag.config.load_config import load_config\n", - "from graphrag.index.typing.pipeline_run_result import PipelineRunResult" + "from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n", + "\n", + "import graphrag.api as api" ] }, { diff --git a/docs/examples_notebooks/input_documents.ipynb b/docs/examples_notebooks/input_documents.ipynb index 5657770eaf..505c0fe1f3 100644 --- a/docs/examples_notebooks/input_documents.ipynb +++ b/docs/examples_notebooks/input_documents.ipynb @@ -30,10 +30,11 @@ "from pathlib import Path\n", "from pprint import pprint\n", "\n", - "import graphrag.api as api\n", "import pandas as pd\n", "from graphrag.config.load_config import load_config\n", - "from graphrag.index.typing.pipeline_run_result import PipelineRunResult" + "from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n", + "\n", + "import graphrag.api as api" ] }, { diff --git a/packages/graphrag/graphrag/query/context_builder/builders.py b/packages/graphrag/graphrag/query/context_builder/builders.py index 736c18745f..b66595c9fe 100644 --- a/packages/graphrag/graphrag/query/context_builder/builders.py +++ b/packages/graphrag/graphrag/query/context_builder/builders.py @@ -24,6 +24,15 @@ class ContextBuilderResult: output_tokens: int = 0 +@dataclass +class LLMParameters: + """A class to hold LLM call parameters.""" + + llm_calls: int = 0 + prompt_tokens: int = 0 + output_tokens: int = 0 + + class GlobalContextBuilder(ABC): """Base class for global-search context builders.""" @@ -36,6 +45,22 @@ async def build_context( ) -> ContextBuilderResult: """Build the context for the global search mode.""" + @abstractmethod + async def build_context_chunks( + self, + query: str, + **kwargs, + ) -> str | list[str]: + """Build the context chunks for the global search mode.""" + + @abstractmethod + async def build_context_records( + self, + query: str, + **kwargs, + ) -> dict[str, pd.DataFrame]: + """Build the context records for the global search mode.""" + class LocalContextBuilder(ABC): """Base class for local-search context builders.""" @@ -49,6 +74,22 @@ def build_context( ) -> ContextBuilderResult: """Build the context for the local search mode.""" + @abstractmethod + def build_context_chunks( + self, + query: str, + **kwargs, + ) -> str: + """Build the context chunks for the local search mode.""" + + @abstractmethod + def build_context_records( + self, + query: str, + **kwargs, + ) -> dict[str, pd.DataFrame]: + """Build the context records for the local search mode.""" + class DRIFTContextBuilder(ABC): """Base class for DRIFT-search context builders.""" @@ -66,10 +107,13 @@ class BasicContextBuilder(ABC): """Base class for basic-search context builders.""" @abstractmethod - def build_context( + def build_context_records( self, query: str, - conversation_history: ConversationHistory | None = None, **kwargs, - ) -> ContextBuilderResult: - """Build the context for the basic search mode.""" + ) -> dict[str, pd.DataFrame]: + """Build the context records for the basic search mode.""" + + @abstractmethod + def get_llm_values(self) -> LLMParameters: + """Get the LLM call values.""" diff --git a/packages/graphrag/graphrag/query/structured_search/base.py b/packages/graphrag/graphrag/query/structured_search/base.py index 753b419f69..67108a55d7 100644 --- a/packages/graphrag/graphrag/query/structured_search/base.py +++ b/packages/graphrag/graphrag/query/structured_search/base.py @@ -90,3 +90,13 @@ async def stream_search( yield "" # This makes it an async generator. msg = "Subclasses must implement this method" raise NotImplementedError(msg) + + @abstractmethod + async def format_records( + self, + records: dict[str, pd.DataFrame], + column_delimiter: str = "|", + ) -> str | list[str]: + """Format context records into a string representation.""" + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) diff --git a/packages/graphrag/graphrag/query/structured_search/basic_search/basic_context.py b/packages/graphrag/graphrag/query/structured_search/basic_search/basic_context.py index b7390017fc..d2bf3ef1ca 100644 --- a/packages/graphrag/graphrag/query/structured_search/basic_search/basic_context.py +++ b/packages/graphrag/graphrag/query/structured_search/basic_search/basic_context.py @@ -12,9 +12,8 @@ from graphrag.language_model.protocol.base import EmbeddingModel from graphrag.query.context_builder.builders import ( BasicContextBuilder, - ContextBuilderResult, + LLMParameters, ) -from graphrag.query.context_builder.conversation_history import ConversationHistory from graphrag.tokenizer.get_tokenizer import get_tokenizer from graphrag.tokenizer.tokenizer import Tokenizer from graphrag.vector_stores.base import BaseVectorStore @@ -39,10 +38,9 @@ def __init__( self.text_unit_embeddings = text_unit_embeddings self.embedding_vectorstore_key = embedding_vectorstore_key - def build_context( + def build_context_records( self, query: str, - conversation_history: ConversationHistory | None = None, k: int = 10, max_context_tokens: int = 12_000, context_name: str = "Sources", @@ -50,7 +48,7 @@ def build_context( text_id_col: str = "id", text_col: str = "text", **kwargs, - ) -> ContextBuilderResult: + ) -> dict[str, pd.DataFrame]: """Build the context for the basic search mode.""" if query != "": related_texts = self.text_unit_embeddings.similarity_search_by_text( @@ -95,11 +93,13 @@ def build_context( drop=True ), ) - final_text = final_text_df.to_csv( - index=False, escapechar="\\", sep=column_delimiter - ) - return ContextBuilderResult( - context_chunks=final_text, - context_records={context_name.lower(): final_text_df}, + return {context_name.lower(): final_text_df} + + def get_llm_values(self) -> LLMParameters: + """Get the LLM call values.""" + return LLMParameters( + llm_calls=0, + prompt_tokens=0, + output_tokens=0, ) diff --git a/packages/graphrag/graphrag/query/structured_search/basic_search/search.py b/packages/graphrag/graphrag/query/structured_search/basic_search/search.py index ce5f656845..5b4c4b4719 100644 --- a/packages/graphrag/graphrag/query/structured_search/basic_search/search.py +++ b/packages/graphrag/graphrag/query/structured_search/basic_search/search.py @@ -15,7 +15,10 @@ ) from graphrag.query.context_builder.builders import BasicContextBuilder from graphrag.query.context_builder.conversation_history import ConversationHistory -from graphrag.query.structured_search.base import BaseSearch, SearchResult +from graphrag.query.structured_search.base import ( + BaseSearch, + SearchResult, +) from graphrag.tokenizer.tokenizer import Tokenizer logger = logging.getLogger(__name__) @@ -49,6 +52,17 @@ def __init__( self.callbacks = callbacks or [] self.response_type = response_type + async def format_records(self, records, column_delimiter="|") -> str | list[str]: + """Format context records into a string representation.""" + if len(records) == 1: + _, context_records_df = next(iter(records.items())) + + if context_records_df is not None: + return context_records_df.to_csv( + index=False, escapechar="\\", sep=column_delimiter + ) + return "" + async def search( self, query: str, @@ -59,22 +73,30 @@ async def search( start_time = time.time() search_prompt = "" llm_calls, prompt_tokens, output_tokens = {}, {}, {} + context_chunks: str | list[str] = "" + column_delimiter: str = "|" - context_result = self.context_builder.build_context( + context_records = self.context_builder.build_context_records( query=query, conversation_history=conversation_history, **kwargs, **self.context_builder_params, ) - llm_calls["build_context"] = context_result.llm_calls - prompt_tokens["build_context"] = context_result.prompt_tokens - output_tokens["build_context"] = context_result.output_tokens + llm_values = self.context_builder.get_llm_values() + llm_calls["build_context"] = llm_values.llm_calls + prompt_tokens["build_context"] = llm_values.prompt_tokens + output_tokens["build_context"] = llm_values.output_tokens logger.debug("GENERATE ANSWER: %s. QUERY: %s", start_time, query) try: + context_chunks = await self.format_records( + records=context_records, + column_delimiter=column_delimiter, + ) + search_prompt = self.system_prompt.format( - context_data=context_result.context_chunks, + context_data=context_chunks, response_type=self.response_type, ) search_messages = [ @@ -96,12 +118,12 @@ async def search( output_tokens["response"] = len(self.tokenizer.encode(response)) for callback in self.callbacks: - callback.on_context(context_result.context_records) + callback.on_context(context_records) return SearchResult( response=response, - context_data=context_result.context_records, - context_text=context_result.context_chunks, + context_data=context_records, + context_text=context_chunks, completion_time=time.time() - start_time, llm_calls=1, prompt_tokens=len(self.tokenizer.encode(search_prompt)), @@ -115,8 +137,8 @@ async def search( logger.exception("Exception in _asearch") return SearchResult( response="", - context_data=context_result.context_records, - context_text=context_result.context_chunks, + context_data=context_records, + context_text=context_chunks, completion_time=time.time() - start_time, llm_calls=1, prompt_tokens=len(self.tokenizer.encode(search_prompt)), @@ -133,22 +155,31 @@ async def stream_search( ) -> AsyncGenerator[str, None]: """Build basic search context that fits a single context window and generate answer for the user query.""" start_time = time.time() + context_chunks: str | list[str] = "" + column_delimiter: str = "|" - context_result = self.context_builder.build_context( + context_records = self.context_builder.build_context_records( query=query, conversation_history=conversation_history, **self.context_builder_params, ) + logger.debug("GENERATE ANSWER: %s. QUERY: %s", start_time, query) + + context_chunks = await self.format_records( + records=context_records, + column_delimiter=column_delimiter, + ) + search_prompt = self.system_prompt.format( - context_data=context_result.context_chunks, response_type=self.response_type + context_data=context_chunks, response_type=self.response_type ) search_messages = [ {"role": "system", "content": search_prompt}, ] for callback in self.callbacks: - callback.on_context(context_result.context_records) + callback.on_context(context_records) async for chunk_response in self.model.achat_stream( prompt=query, diff --git a/packages/graphrag/graphrag/query/structured_search/drift_search/search.py b/packages/graphrag/graphrag/query/structured_search/drift_search/search.py index 64a8e52b43..e1ca85e5cf 100644 --- a/packages/graphrag/graphrag/query/structured_search/drift_search/search.py +++ b/packages/graphrag/graphrag/query/structured_search/drift_search/search.py @@ -300,6 +300,17 @@ async def search( output_tokens_categories=output_tokens, ) + async def format_records(self, records, column_delimiter="|") -> str | list[str]: + """Format context records into a string representation.""" + if len(records) == 1: + _, context_records_df = next(iter(records.items())) + + if context_records_df is not None: + return context_records_df.to_csv( + index=False, escapechar="\\", sep=column_delimiter + ) + return "" + async def stream_search( self, query: str, conversation_history: ConversationHistory | None = None ) -> AsyncGenerator[str, None]: diff --git a/packages/graphrag/graphrag/query/structured_search/global_search/community_context.py b/packages/graphrag/graphrag/query/structured_search/global_search/community_context.py index 1709aab1a8..e6efab6dda 100644 --- a/packages/graphrag/graphrag/query/structured_search/global_search/community_context.py +++ b/packages/graphrag/graphrag/query/structured_search/global_search/community_context.py @@ -5,6 +5,8 @@ from typing import Any +import pandas as pd + from graphrag.data_model.community import Community from graphrag.data_model.community_report import CommunityReport from graphrag.data_model.entity import Entity @@ -52,7 +54,7 @@ def __init__( ) self.random_state = random_state - async def build_context( + async def build_context_chunks( self, query: str, conversation_history: ConversationHistory | None = None, @@ -70,7 +72,7 @@ async def build_context( conversation_history_user_turns_only: bool = True, conversation_history_max_turns: int | None = 5, **kwargs: Any, - ) -> ContextBuilderResult: + ) -> str | list[str]: """Prepare batches of community report data table as context data for global search.""" conversation_history_context = "" final_context_data = {} @@ -135,9 +137,133 @@ async def build_context( # Update the final context data with the provided community_context_data final_context_data.update(community_context_data) + return final_context + + async def build_context_records( + self, + query: str, + conversation_history: ConversationHistory | None = None, + use_community_summary: bool = True, + column_delimiter: str = "|", + shuffle_data: bool = True, + include_community_rank: bool = False, + min_community_rank: int = 0, + community_rank_name: str = "rank", + include_community_weight: bool = True, + community_weight_name: str = "occurrence", + normalize_community_weight: bool = True, + max_context_tokens: int = 8000, + context_name: str = "Reports", + conversation_history_user_turns_only: bool = True, + conversation_history_max_turns: int | None = 5, + **kwargs: Any, + ) -> dict[str, pd.DataFrame]: + """Prepare batches of community report data table as context data for global search.""" + conversation_history_context = "" + final_context_data = {} + if conversation_history: + # build conversation history context + ( + conversation_history_context, + conversation_history_context_data, + ) = conversation_history.build_context( + include_user_turns_only=conversation_history_user_turns_only, + max_qa_turns=conversation_history_max_turns, + column_delimiter=column_delimiter, + max_context_tokens=max_context_tokens, + recency_bias=False, + ) + if conversation_history_context != "": + final_context_data = conversation_history_context_data + + community_reports = self.community_reports + + _, community_context_data = build_community_context( + community_reports=community_reports, + entities=self.entities, + tokenizer=self.tokenizer, + use_community_summary=use_community_summary, + column_delimiter=column_delimiter, + shuffle_data=shuffle_data, + include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + community_rank_name=community_rank_name, + include_community_weight=include_community_weight, + community_weight_name=community_weight_name, + normalize_community_weight=normalize_community_weight, + max_context_tokens=max_context_tokens, + single_batch=False, + context_name=context_name, + random_state=self.random_state, + ) + + # Update the final context data with the provided community_context_data + final_context_data.update(community_context_data) + + return final_context_data + + async def build_context( + self, + query: str, + conversation_history: ConversationHistory | None = None, + use_community_summary: bool = True, + column_delimiter: str = "|", + shuffle_data: bool = True, + include_community_rank: bool = False, + min_community_rank: int = 0, + community_rank_name: str = "rank", + include_community_weight: bool = True, + community_weight_name: str = "occurrence", + normalize_community_weight: bool = True, + max_context_tokens: int = 8000, + context_name: str = "Reports", + **kwargs: Any, + ) -> ContextBuilderResult: + """Prepare batches of community report data table as context data for global search.""" + final_context_data = {} + llm_calls, prompt_tokens, output_tokens = 0, 0, 0 + + community_reports = self.community_reports + if self.dynamic_community_selection is not None: + ( + community_reports, + dynamic_info, + ) = await self.dynamic_community_selection.select(query) + llm_calls += dynamic_info["llm_calls"] + prompt_tokens += dynamic_info["prompt_tokens"] + output_tokens += dynamic_info["output_tokens"] + + _, community_context_data = build_community_context( + community_reports=community_reports, + entities=self.entities, + tokenizer=self.tokenizer, + use_community_summary=use_community_summary, + column_delimiter=column_delimiter, + shuffle_data=shuffle_data, + include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + community_rank_name=community_rank_name, + include_community_weight=include_community_weight, + community_weight_name=community_weight_name, + normalize_community_weight=normalize_community_weight, + max_context_tokens=max_context_tokens, + single_batch=False, + context_name=context_name, + random_state=self.random_state, + ) + + # Update the final context data with the provided community_context_data + final_context_data.update(community_context_data) + return ContextBuilderResult( - context_chunks=final_context, - context_records=final_context_data, + context_chunks=await self.build_context_chunks( + query=query, + **kwargs, + ), + context_records=await self.build_context_records( + query=query, + **kwargs, + ), llm_calls=llm_calls, prompt_tokens=prompt_tokens, output_tokens=output_tokens, diff --git a/packages/graphrag/graphrag/query/structured_search/global_search/search.py b/packages/graphrag/graphrag/query/structured_search/global_search/search.py index 86b95d0088..968454f4f9 100644 --- a/packages/graphrag/graphrag/query/structured_search/global_search/search.py +++ b/packages/graphrag/graphrag/query/structured_search/global_search/search.py @@ -96,6 +96,17 @@ def __init__( self.semaphore = asyncio.Semaphore(concurrent_coroutines) + async def format_records(self, records, column_delimiter="|") -> str | list[str]: + """Format context records into a string representation.""" + if len(records) == 1: + _, context_records_df = next(iter(records.items())) + + if context_records_df is not None: + return context_records_df.to_csv( + index=False, escapechar="\\", sep=column_delimiter + ) + return "" + async def stream_search( self, query: str, diff --git a/packages/graphrag/graphrag/query/structured_search/local_search/mixed_context.py b/packages/graphrag/graphrag/query/structured_search/local_search/mixed_context.py index b91272d164..1f61840215 100644 --- a/packages/graphrag/graphrag/query/structured_search/local_search/mixed_context.py +++ b/packages/graphrag/graphrag/query/structured_search/local_search/mixed_context.py @@ -84,7 +84,25 @@ def __init__( self.tokenizer = tokenizer or get_tokenizer() self.embedding_vectorstore_key = embedding_vectorstore_key - def build_context( + def validate_context( + self, + include_entity_names: list[str] | None = None, + exclude_entity_names: list[str] | None = None, + text_unit_prop: float = 0.5, + community_prop: float = 0.25, + ): + """Validate context building parameters.""" + if include_entity_names is None: + include_entity_names = [] + if exclude_entity_names is None: + exclude_entity_names = [] + if community_prop + text_unit_prop > 1: + value_error = ( + "The sum of community_prop and text_unit_prop should not exceed 1." + ) + raise ValueError(value_error) + + def build_context_chunks( self, query: str, conversation_history: ConversationHistory | None = None, @@ -108,21 +126,145 @@ def build_context( community_context_name: str = "Reports", column_delimiter: str = "|", **kwargs: dict[str, Any], - ) -> ContextBuilderResult: + ) -> str: """ Build data context for local search prompt. Build a context by combining community reports and entity/relationship/covariate tables, and text units using a predefined ratio set by summary_prop. """ - if include_entity_names is None: - include_entity_names = [] - if exclude_entity_names is None: - exclude_entity_names = [] - if community_prop + text_unit_prop > 1: - value_error = ( - "The sum of community_prop and text_unit_prop should not exceed 1." + self.validate_context( + include_entity_names=include_entity_names, + exclude_entity_names=exclude_entity_names, + text_unit_prop=text_unit_prop, + community_prop=community_prop, + ) + + # map user query to entities + # if there is conversation history, attached the previous user questions to the current query + if conversation_history: + pre_user_questions = "\n".join( + conversation_history.get_user_turns(conversation_history_max_turns) ) - raise ValueError(value_error) + query = f"{query}\n{pre_user_questions}" + + selected_entities = map_query_to_entities( + query=query, + text_embedding_vectorstore=self.entity_text_embeddings, + text_embedder=self.text_embedder, + all_entities_dict=self.entities, + embedding_vectorstore_key=self.embedding_vectorstore_key, + include_entity_names=include_entity_names, + exclude_entity_names=exclude_entity_names, + k=top_k_mapped_entities, + oversample_scaler=2, + ) + + # build context + final_context = list[str]() + final_context_data = dict[str, pd.DataFrame]() + + if conversation_history: + # build conversation history context + ( + conversation_history_context, + conversation_history_context_data, + ) = conversation_history.build_context( + include_user_turns_only=conversation_history_user_turns_only, + max_qa_turns=conversation_history_max_turns, + column_delimiter=column_delimiter, + max_context_tokens=max_context_tokens, + recency_bias=False, + ) + if conversation_history_context.strip() != "": + final_context.append(conversation_history_context) + final_context_data = conversation_history_context_data + max_context_tokens = max_context_tokens - len( + self.tokenizer.encode(conversation_history_context) + ) + + # build community context + community_tokens = max(int(max_context_tokens * community_prop), 0) + community_context, community_context_data = self._build_community_context( + selected_entities=selected_entities, + max_context_tokens=community_tokens, + use_community_summary=use_community_summary, + column_delimiter=column_delimiter, + include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + return_candidate_context=return_candidate_context, + context_name=community_context_name, + ) + if community_context.strip() != "": + final_context.append(community_context) + final_context_data = {**final_context_data, **community_context_data} + + # build local (i.e. entity-relationship-covariate) context + local_prop = 1 - community_prop - text_unit_prop + local_tokens = max(int(max_context_tokens * local_prop), 0) + local_context, local_context_data = self._build_local_context( + selected_entities=selected_entities, + max_context_tokens=local_tokens, + include_entity_rank=include_entity_rank, + rank_description=rank_description, + include_relationship_weight=include_relationship_weight, + top_k_relationships=top_k_relationships, + relationship_ranking_attribute=relationship_ranking_attribute, + return_candidate_context=return_candidate_context, + column_delimiter=column_delimiter, + ) + if local_context.strip() != "": + final_context.append(str(local_context)) + final_context_data = {**final_context_data, **local_context_data} + + text_unit_tokens = max(int(max_context_tokens * text_unit_prop), 0) + text_unit_context, text_unit_context_data = self._build_text_unit_context( + selected_entities=selected_entities, + max_context_tokens=text_unit_tokens, + return_candidate_context=return_candidate_context, + ) + + if text_unit_context.strip() != "": + final_context.append(text_unit_context) + final_context_data = {**final_context_data, **text_unit_context_data} + + return "\n\n".join(final_context) + + def build_context_records( + self, + query: str, + conversation_history: ConversationHistory | None = None, + include_entity_names: list[str] | None = None, + exclude_entity_names: list[str] | None = None, + conversation_history_max_turns: int | None = 5, + conversation_history_user_turns_only: bool = True, + max_context_tokens: int = 8000, + text_unit_prop: float = 0.5, + community_prop: float = 0.25, + top_k_mapped_entities: int = 10, + top_k_relationships: int = 10, + include_community_rank: bool = False, + include_entity_rank: bool = False, + rank_description: str = "number of relationships", + include_relationship_weight: bool = False, + relationship_ranking_attribute: str = "rank", + return_candidate_context: bool = False, + use_community_summary: bool = False, + min_community_rank: int = 0, + community_context_name: str = "Reports", + column_delimiter: str = "|", + **kwargs: dict[str, Any], + ) -> dict[str, pd.DataFrame]: + """ + Build data context for local search prompt. + + Build a context by combining community reports and entity/relationship/covariate tables, and text units using a predefined ratio set by summary_prop. + """ + self.validate_context( + include_entity_names=include_entity_names, + exclude_entity_names=exclude_entity_names, + text_unit_prop=text_unit_prop, + community_prop=community_prop, + ) # map user query to entities # if there is conversation history, attached the previous user questions to the current query @@ -212,9 +354,22 @@ def build_context( final_context.append(text_unit_context) final_context_data = {**final_context_data, **text_unit_context_data} + return final_context_data + + def build_context( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> ContextBuilderResult: + """ + Build data context for local search prompt. + + Build a context by combining community reports and entity/relationship/covariate tables, and text units using a predefined ratio set by summary_prop. + """ return ContextBuilderResult( - context_chunks="\n\n".join(final_context), - context_records=final_context_data, + context_chunks=self.build_context_chunks(query=query, **kwargs), + context_records=self.build_context_records(query=query, **kwargs), ) def _build_community_context( diff --git a/packages/graphrag/graphrag/query/structured_search/local_search/search.py b/packages/graphrag/graphrag/query/structured_search/local_search/search.py index fdd72949da..24827d4396 100644 --- a/packages/graphrag/graphrag/query/structured_search/local_search/search.py +++ b/packages/graphrag/graphrag/query/structured_search/local_search/search.py @@ -48,6 +48,17 @@ def __init__( self.callbacks = callbacks or [] self.response_type = response_type + async def format_records(self, records, column_delimiter="|") -> str | list[str]: + """Format context records into a string representation.""" + if len(records) == 1: + _, context_records_df = next(iter(records.items())) + + if context_records_df is not None: + return context_records_df.to_csv( + index=False, escapechar="\\", sep=column_delimiter + ) + return "" + async def search( self, query: str, diff --git a/tests/verbs/test_create_community_reports.py b/tests/verbs/test_create_community_reports.py index d479120ce2..561f54108b 100644 --- a/tests/verbs/test_create_community_reports.py +++ b/tests/verbs/test_create_community_reports.py @@ -4,15 +4,16 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS -from graphrag.index.operations.summarize_communities.community_reports_extractor import ( - CommunityReportResponse, - FindingModel, -) from graphrag.index.workflows.create_community_reports import ( run_workflow, ) from graphrag.utils.storage import load_table_from_storage +from graphrag.index.operations.summarize_communities.community_reports_extractor import ( + CommunityReportResponse, + FindingModel, +) + from .util import ( DEFAULT_MODEL_CONFIG, compare_outputs, diff --git a/unified-search-app/app/app_logic.py b/unified-search-app/app/app_logic.py index dc64e0e77c..a573b9daa5 100644 --- a/unified-search-app/app/app_logic.py +++ b/unified-search-app/app/app_logic.py @@ -7,7 +7,6 @@ import logging from typing import TYPE_CHECKING -import graphrag.api as api import streamlit as st from knowledge_loader.data_sources.loader import ( create_datasource, @@ -18,6 +17,8 @@ from state.session_variables import SessionVariables from ui.search import display_search_result +import graphrag.api as api + if TYPE_CHECKING: import pandas as pd