diff --git a/.semversioner/next-release/minor-20250325000101658359.json b/.semversioner/next-release/minor-20250325000101658359.json new file mode 100644 index 0000000000..d525e08490 --- /dev/null +++ b/.semversioner/next-release/minor-20250325000101658359.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Support OpenAI reasoning models." +} diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 3977ed5820..cea7ba8ea3 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -41,13 +41,7 @@ class BasicSearchDefaults: """Default values for basic search.""" prompt: None = None - text_unit_prop: float = 0.5 - conversation_history_max_turns: int = 5 - temperature: float = 0 - top_p: float = 1 - n: int = 1 - max_tokens: int = 12_000 - llm_max_tokens: int = 2000 + k: int = 10 chat_model_id: str = DEFAULT_CHAT_MODEL_ID embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID @@ -104,13 +98,10 @@ class DriftSearchDefaults: prompt: None = None reduce_prompt: None = None - temperature: float = 0 - top_p: float = 1 - n: int = 1 - max_tokens: int = 12_000 data_max_tokens: int = 12_000 - reduce_max_tokens: int = 2_000 + reduce_max_tokens: None = None reduce_temperature: float = 0 + reduce_max_completion_tokens: None = None concurrency: int = 32 drift_k_followups: int = 20 primer_folds: int = 5 @@ -124,7 +115,8 @@ class DriftSearchDefaults: local_search_temperature: float = 0 local_search_top_p: float = 1 local_search_n: int = 1 - local_search_llm_max_gen_tokens: int = 4_096 + local_search_llm_max_gen_tokens: None = None + local_search_llm_max_gen_completion_tokens: None = None chat_model_id: str = DEFAULT_CHAT_MODEL_ID embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID @@ -168,7 +160,6 @@ class ExtractClaimsDefaults: ) max_gleanings: int = 1 strategy: None = None - encoding_model: None = None model_id: str = DEFAULT_CHAT_MODEL_ID @@ -182,7 +173,6 @@ class ExtractGraphDefaults: ) max_gleanings: int = 1 strategy: None = None - encoding_model: None = None model_id: str = DEFAULT_CHAT_MODEL_ID @@ -228,20 +218,14 @@ class GlobalSearchDefaults: map_prompt: None = None reduce_prompt: None = None knowledge_prompt: None = None - temperature: float = 0 - top_p: float = 1 - n: int = 1 - max_tokens: int = 12_000 + max_context_tokens: int = 12_000 data_max_tokens: int = 12_000 - map_max_tokens: int = 1000 - reduce_max_tokens: int = 2000 - concurrency: int = 32 - dynamic_search_llm: str = "gpt-4o-mini" + map_max_length: int = 1000 + reduce_max_length: int = 2000 dynamic_search_threshold: int = 1 dynamic_search_keep_parent: bool = False dynamic_search_num_repeats: int = 1 dynamic_search_use_summary: bool = False - dynamic_search_concurrent_coroutines: int = 16 dynamic_search_max_level: int = 2 chat_model_id: str = DEFAULT_CHAT_MODEL_ID @@ -271,8 +255,10 @@ class LanguageModelDefaults: api_key: None = None auth_type = AuthType.APIKey encoding_model: str = "" - max_tokens: int = 4000 + max_tokens: int | None = None temperature: float = 0 + max_completion_tokens: int | None = None + reasoning_effort: str | None = None top_p: float = 1 n: int = 1 frequency_penalty: float = 0.0 @@ -305,11 +291,7 @@ class LocalSearchDefaults: conversation_history_max_turns: int = 5 top_k_entities: int = 10 top_k_relationships: int = 10 - temperature: float = 0 - top_p: float = 1 - n: int = 1 - max_tokens: int = 12_000 - llm_max_tokens: int = 2000 + max_context_tokens: int = 12_000 chat_model_id: str = DEFAULT_CHAT_MODEL_ID embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID @@ -364,6 +346,7 @@ class 
SummarizeDescriptionsDefaults: prompt: None = None max_length: int = 500 + max_input_tokens: int = 4_000 strategy: None = None model_id: str = DEFAULT_CHAT_MODEL_ID diff --git a/graphrag/config/models/basic_search_config.py b/graphrag/config/models/basic_search_config.py index e1bdfbdcfe..8221cd3ff5 100644 --- a/graphrag/config/models/basic_search_config.py +++ b/graphrag/config/models/basic_search_config.py @@ -23,31 +23,7 @@ class BasicSearchConfig(BaseModel): description="The model ID to use for text embeddings.", default=graphrag_config_defaults.basic_search.embedding_model_id, ) - text_unit_prop: float = Field( - description="The text unit proportion.", - default=graphrag_config_defaults.basic_search.text_unit_prop, - ) - conversation_history_max_turns: int = Field( - description="The conversation history maximum turns.", - default=graphrag_config_defaults.basic_search.conversation_history_max_turns, - ) - temperature: float = Field( - description="The temperature to use for token generation.", - default=graphrag_config_defaults.basic_search.temperature, - ) - top_p: float = Field( - description="The top-p value to use for token generation.", - default=graphrag_config_defaults.basic_search.top_p, - ) - n: int = Field( - description="The number of completions to generate.", - default=graphrag_config_defaults.basic_search.n, - ) - max_tokens: int = Field( - description="The maximum tokens.", - default=graphrag_config_defaults.basic_search.max_tokens, - ) - llm_max_tokens: int = Field( - description="The LLM maximum tokens.", - default=graphrag_config_defaults.basic_search.llm_max_tokens, + k: int = Field( + description="The number of text units to include in search context.", + default=graphrag_config_defaults.basic_search.k, ) diff --git a/graphrag/config/models/community_reports_config.py b/graphrag/config/models/community_reports_config.py index 0b765390ec..b4e9259489 100644 --- a/graphrag/config/models/community_reports_config.py +++ b/graphrag/config/models/community_reports_config.py @@ -50,7 +50,6 @@ def resolved_strategy( return self.strategy or { "type": CreateCommunityReportsStrategyType.graph_intelligence, "llm": model_config.model_dump(), - "num_threads": model_config.concurrent_requests, "graph_prompt": (Path(root_dir) / self.graph_prompt).read_text( encoding="utf-8" ) diff --git a/graphrag/config/models/drift_search_config.py b/graphrag/config/models/drift_search_config.py index 88c0d35702..a6edf66474 100644 --- a/graphrag/config/models/drift_search_config.py +++ b/graphrag/config/models/drift_search_config.py @@ -27,28 +27,12 @@ class DRIFTSearchConfig(BaseModel): description="The model ID to use for drift search.", default=graphrag_config_defaults.drift_search.embedding_model_id, ) - temperature: float = Field( - description="The temperature to use for token generation.", - default=graphrag_config_defaults.drift_search.temperature, - ) - top_p: float = Field( - description="The top-p value to use for token generation.", - default=graphrag_config_defaults.drift_search.top_p, - ) - n: int = Field( - description="The number of completions to generate.", - default=graphrag_config_defaults.drift_search.n, - ) - max_tokens: int = Field( - description="The maximum context size in tokens.", - default=graphrag_config_defaults.drift_search.max_tokens, - ) data_max_tokens: int = Field( description="The data llm maximum tokens.", default=graphrag_config_defaults.drift_search.data_max_tokens, ) - reduce_max_tokens: int = Field( + reduce_max_tokens: int | None = Field( 
description="The reduce llm maximum tokens response to produce.", default=graphrag_config_defaults.drift_search.reduce_max_tokens, ) @@ -58,6 +42,11 @@ class DRIFTSearchConfig(BaseModel): default=graphrag_config_defaults.drift_search.reduce_temperature, ) + reduce_max_completion_tokens: int | None = Field( + description="The reduce llm maximum tokens response to produce.", + default=graphrag_config_defaults.drift_search.reduce_max_completion_tokens, + ) + concurrency: int = Field( description="The number of concurrent requests.", default=graphrag_config_defaults.drift_search.concurrency, @@ -123,7 +112,12 @@ class DRIFTSearchConfig(BaseModel): default=graphrag_config_defaults.drift_search.local_search_n, ) - local_search_llm_max_gen_tokens: int = Field( + local_search_llm_max_gen_tokens: int | None = Field( description="The maximum number of generated tokens for the LLM in local search.", default=graphrag_config_defaults.drift_search.local_search_llm_max_gen_tokens, ) + + local_search_llm_max_gen_completion_tokens: int | None = Field( + description="The maximum number of generated tokens for the LLM in local search.", + default=graphrag_config_defaults.drift_search.local_search_llm_max_gen_completion_tokens, + ) diff --git a/graphrag/config/models/extract_claims_config.py b/graphrag/config/models/extract_claims_config.py index bb6ba6370e..166cc29d4e 100644 --- a/graphrag/config/models/extract_claims_config.py +++ b/graphrag/config/models/extract_claims_config.py @@ -38,10 +38,6 @@ class ClaimExtractionConfig(BaseModel): description="The override strategy to use.", default=graphrag_config_defaults.extract_claims.strategy, ) - encoding_model: str | None = Field( - default=graphrag_config_defaults.extract_claims.encoding_model, - description="The encoding model to use.", - ) def resolved_strategy( self, root_dir: str, model_config: LanguageModelConfig @@ -49,7 +45,6 @@ def resolved_strategy( """Get the resolved claim extraction strategy.""" return self.strategy or { "llm": model_config.model_dump(), - "num_threads": model_config.concurrent_requests, "extraction_prompt": (Path(root_dir) / self.prompt).read_text( encoding="utf-8" ) @@ -57,5 +52,4 @@ def resolved_strategy( else None, "claim_description": self.description, "max_gleanings": self.max_gleanings, - "encoding_name": model_config.encoding_model, } diff --git a/graphrag/config/models/extract_graph_config.py b/graphrag/config/models/extract_graph_config.py index 1ad29cb699..915ff5d8a5 100644 --- a/graphrag/config/models/extract_graph_config.py +++ b/graphrag/config/models/extract_graph_config.py @@ -34,10 +34,6 @@ class ExtractGraphConfig(BaseModel): description="Override the default entity extraction strategy", default=graphrag_config_defaults.extract_graph.strategy, ) - encoding_model: str | None = Field( - default=graphrag_config_defaults.extract_graph.encoding_model, - description="The encoding model to use.", - ) def resolved_strategy( self, root_dir: str, model_config: LanguageModelConfig @@ -50,12 +46,10 @@ def resolved_strategy( return self.strategy or { "type": ExtractEntityStrategyType.graph_intelligence, "llm": model_config.model_dump(), - "num_threads": model_config.concurrent_requests, "extraction_prompt": (Path(root_dir) / self.prompt).read_text( encoding="utf-8" ) if self.prompt else None, "max_gleanings": self.max_gleanings, - "encoding_name": model_config.encoding_model, } diff --git a/graphrag/config/models/global_search_config.py b/graphrag/config/models/global_search_config.py index 210caa1a72..c350efcea6 100644 --- 
a/graphrag/config/models/global_search_config.py +++ b/graphrag/config/models/global_search_config.py @@ -27,44 +27,24 @@ class GlobalSearchConfig(BaseModel): description="The global search general prompt to use.", default=graphrag_config_defaults.global_search.knowledge_prompt, ) - temperature: float = Field( - description="The temperature to use for token generation.", - default=graphrag_config_defaults.global_search.temperature, - ) - top_p: float = Field( - description="The top-p value to use for token generation.", - default=graphrag_config_defaults.global_search.top_p, - ) - n: int = Field( - description="The number of completions to generate.", - default=graphrag_config_defaults.global_search.n, - ) - max_tokens: int = Field( + max_context_tokens: int = Field( description="The maximum context size in tokens.", - default=graphrag_config_defaults.global_search.max_tokens, + default=graphrag_config_defaults.global_search.max_context_tokens, ) data_max_tokens: int = Field( description="The data llm maximum tokens.", default=graphrag_config_defaults.global_search.data_max_tokens, ) - map_max_tokens: int = Field( - description="The map llm maximum tokens.", - default=graphrag_config_defaults.global_search.map_max_tokens, + map_max_length: int = Field( + description="The map llm maximum response length in words.", + default=graphrag_config_defaults.global_search.map_max_length, ) - reduce_max_tokens: int = Field( - description="The reduce llm maximum tokens.", - default=graphrag_config_defaults.global_search.reduce_max_tokens, - ) - concurrency: int = Field( - description="The number of concurrent requests.", - default=graphrag_config_defaults.global_search.concurrency, + reduce_max_length: int = Field( + description="The reduce llm maximum response length in words.", + default=graphrag_config_defaults.global_search.reduce_max_length, ) # configurations for dynamic community selection - dynamic_search_llm: str = Field( - description="LLM model to use for dynamic community selection", - default=graphrag_config_defaults.global_search.dynamic_search_llm, - ) dynamic_search_threshold: int = Field( description="Rating threshold to include a community report", default=graphrag_config_defaults.global_search.dynamic_search_threshold, @@ -81,10 +61,6 @@ class GlobalSearchConfig(BaseModel): description="Use community summary instead of full_context", default=graphrag_config_defaults.global_search.dynamic_search_use_summary, ) - dynamic_search_concurrent_coroutines: int = Field( - description="Number of concurrent coroutines to rate community reports", - default=graphrag_config_defaults.global_search.dynamic_search_concurrent_coroutines, - ) dynamic_search_max_level: int = Field( description="The maximum level of community hierarchy to consider if none of the processed communities are relevant", default=graphrag_config_defaults.global_search.dynamic_search_max_level, diff --git a/graphrag/config/models/language_model_config.py b/graphrag/config/models/language_model_config.py index 5eadfa6eeb..375fcd177a 100644 --- a/graphrag/config/models/language_model_config.py +++ b/graphrag/config/models/language_model_config.py @@ -223,7 +223,7 @@ def _validate_deployment_name(self) -> None: default=language_model_defaults.responses, description="Static responses to use in mock mode.", ) - max_tokens: int = Field( + max_tokens: int | None = Field( description="The maximum number of tokens to generate.", default=language_model_defaults.max_tokens, ) @@ -231,6 +231,14 @@ def _validate_deployment_name(self) -> 
None: description="The temperature to use for token generation.", default=language_model_defaults.temperature, ) + max_completion_tokens: int | None = Field( + description="The maximum number of tokens to generate. This includes reasoning tokens for the o* reasoning models.", + default=language_model_defaults.max_completion_tokens, + ) + reasoning_effort: str | None = Field( + description="Level of effort OpenAI reasoning models should expend. Supported options are 'low', 'medium', and 'high'; OpenAI defaults to 'medium'.", + default=language_model_defaults.reasoning_effort, + ) top_p: float = Field( description="The top-p value to use for token generation.", default=language_model_defaults.top_p, diff --git a/graphrag/config/models/local_search_config.py b/graphrag/config/models/local_search_config.py index 97d818b238..4cf31ffe0e 100644 --- a/graphrag/config/models/local_search_config.py +++ b/graphrag/config/models/local_search_config.py @@ -43,23 +43,7 @@ class LocalSearchConfig(BaseModel): description="The top k mapped relations.", default=graphrag_config_defaults.local_search.top_k_relationships, ) - temperature: float = Field( - description="The temperature to use for token generation.", - default=graphrag_config_defaults.local_search.temperature, - ) - top_p: float = Field( - description="The top-p value to use for token generation.", - default=graphrag_config_defaults.local_search.top_p, - ) - n: int = Field( - description="The number of completions to generate.", - default=graphrag_config_defaults.local_search.n, - ) - max_tokens: int = Field( + max_context_tokens: int = Field( description="The maximum tokens.", - default=graphrag_config_defaults.local_search.max_tokens, - ) - llm_max_tokens: int = Field( - description="The LLM maximum tokens.", - default=graphrag_config_defaults.local_search.llm_max_tokens, + default=graphrag_config_defaults.local_search.max_context_tokens, ) diff --git a/graphrag/config/models/summarize_descriptions_config.py b/graphrag/config/models/summarize_descriptions_config.py index 7237da136c..3d67de0d54 100644 --- a/graphrag/config/models/summarize_descriptions_config.py +++ b/graphrag/config/models/summarize_descriptions_config.py @@ -26,6 +26,10 @@ class SummarizeDescriptionsConfig(BaseModel): description="The description summarization maximum length.", default=graphrag_config_defaults.summarize_descriptions.max_length, ) + max_input_tokens: int = Field( + description="Maximum tokens to submit from the input entity descriptions.", + default=graphrag_config_defaults.summarize_descriptions.max_input_tokens, + ) strategy: dict | None = Field( description="The override strategy to use.", default=graphrag_config_defaults.summarize_descriptions.strategy, @@ -42,11 +46,11 @@ def resolved_strategy( return self.strategy or { "type": SummarizeStrategyType.graph_intelligence, "llm": model_config.model_dump(), - "num_threads": model_config.concurrent_requests, "summarize_prompt": (Path(root_dir) / self.prompt).read_text( encoding="utf-8" ) if self.prompt else None, "max_summary_length": self.max_length, + "max_input_tokens": self.max_input_tokens, } diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index 0b294b3587..935644b025 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -45,37 +45,7 @@ async def embed_text( id_column: str = "id", title_column: str | None = None, ): - """ - Embed a piece of text into a vector space. 
The operation outputs a new column containing a mapping between doc_id and vector. - - ## Usage - ```yaml - args: - column: text # The name of the column containing the text to embed, this can either be a column with text, or a column with a list[tuple[doc_id, str]] - to: embedding # The name of the column to output the embedding to - strategy: # See strategies section below - ``` - - ## Strategies - The text embed operation uses a strategy to embed the text. The strategy is an object which defines the strategy to use. The following strategies are available: - - ### openai - This strategy uses openai to embed a piece of text. In particular it uses a LLM to embed a piece of text. The strategy config is as follows: - - ```yaml - strategy: - type: openai - llm: # The configuration for the LLM - type: openai_embedding # the type of llm to use, available options are: openai_embedding, azure_openai_embedding - api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai - model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai - max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai - organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai - vector_store: # The optional configuration for the vector store - type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb - <...> - ``` - """ + """Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector.""" vector_store_config = strategy.get("vector_store") if vector_store_config: diff --git a/graphrag/index/operations/embed_text/strategies/openai.py b/graphrag/index/operations/embed_text/strategies/openai.py index 56fe780922..b5dee44335 100644 --- a/graphrag/index/operations/embed_text/strategies/openai.py +++ b/graphrag/index/operations/embed_text/strategies/openai.py @@ -55,7 +55,7 @@ async def run( splitter, ) log.info( - "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, max_tokens=%d", + "embedding %d inputs via %d snippets using %d batches. 
max_batch_size=%d, batch_max_tokens=%d", len(input), len(texts), len(text_batches), diff --git a/graphrag/index/operations/extract_covariates/claim_extractor.py b/graphrag/index/operations/extract_covariates/claim_extractor.py index 04d93b4d68..6ddfe503c0 100644 --- a/graphrag/index/operations/extract_covariates/claim_extractor.py +++ b/graphrag/index/operations/extract_covariates/claim_extractor.py @@ -8,9 +8,7 @@ from dataclasses import dataclass from typing import Any -import tiktoken - -from graphrag.config.defaults import ENCODING_MODEL, graphrag_config_defaults +from graphrag.config.defaults import graphrag_config_defaults from graphrag.index.typing.error_handler import ErrorHandlerFn from graphrag.language_model.protocol.base import ChatModel from graphrag.prompts.index.extract_claims import ( @@ -48,7 +46,6 @@ class ClaimExtractor: _completion_delimiter_key: str _max_gleanings: int _on_error: ErrorHandlerFn - _loop_args: dict[str, Any] def __init__( self, @@ -61,7 +58,6 @@ def __init__( tuple_delimiter_key: str | None = None, record_delimiter_key: str | None = None, completion_delimiter_key: str | None = None, - encoding_model: str | None = None, max_gleanings: int | None = None, on_error: ErrorHandlerFn | None = None, ): @@ -88,12 +84,6 @@ def __init__( ) self._on_error = on_error or (lambda _e, _s, _d: None) - # Construct the looping arguments - encoding = tiktoken.get_encoding(encoding_model or ENCODING_MODEL) - yes = f"{encoding.encode('Y')[0]}" - no = f"{encoding.encode('N')[0]}" - self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1} - async def __call__( self, inputs: dict[str, Any], prompt_variables: dict | None = None ) -> ClaimExtractorResult: @@ -175,30 +165,32 @@ async def _process_document( results = response.output.content or "" claims = results.strip().removesuffix(completion_delimiter) - # Repeat to ensure we maximize entity count - for i in range(self._max_gleanings): - response = await self._model.achat( - CONTINUE_PROMPT, - name=f"extract-continuation-{i}", - history=response.history, - ) - extension = response.output.content or "" - claims += record_delimiter + extension.strip().removesuffix( - completion_delimiter - ) + # if gleanings are specified, enter a loop to extract more claims + # there are two exit criteria: (a) we hit the configured max, (b) the model says there are no more claims + if self._max_gleanings > 0: + for i in range(self._max_gleanings): + response = await self._model.achat( + CONTINUE_PROMPT, + name=f"extract-continuation-{i}", + history=response.history, + ) + extension = response.output.content or "" + claims += record_delimiter + extension.strip().removesuffix( + completion_delimiter + ) - # If this isn't the last loop, check to see if we should continue - if i >= self._max_gleanings - 1: - break + # If this isn't the last loop, check to see if we should continue + if i >= self._max_gleanings - 1: + break - response = await self._model.achat( - LOOP_PROMPT, - name=f"extract-loopcheck-{i}", - history=response.history, - model_parameters=self._loop_args, - ) - if response.output.content != "Y": - break + response = await self._model.achat( + LOOP_PROMPT, + name=f"extract-loopcheck-{i}", + history=response.history, + ) + + if response.output.content != "Y": + break return self._parse_claim_tuples(results, prompt_args) diff --git a/graphrag/index/operations/extract_covariates/extract_covariates.py b/graphrag/index/operations/extract_covariates/extract_covariates.py index 5c18be2505..22e01d1f4f 100644 --- 
a/graphrag/index/operations/extract_covariates/extract_covariates.py +++ b/graphrag/index/operations/extract_covariates/extract_covariates.py @@ -109,13 +109,11 @@ async def run_extract_claims( tuple_delimiter = strategy_config.get("tuple_delimiter") record_delimiter = strategy_config.get("record_delimiter") completion_delimiter = strategy_config.get("completion_delimiter") - encoding_model = strategy_config.get("encoding_name") extractor = ClaimExtractor( model_invoker=llm, extraction_prompt=extraction_prompt, max_gleanings=max_gleanings, - encoding_model=encoding_model, on_error=lambda e, s, d: ( callbacks.error("Claim Extraction Error", e, s, d) if callbacks else None ), diff --git a/graphrag/index/operations/extract_graph/extract_graph.py b/graphrag/index/operations/extract_graph/extract_graph.py index 1f0c26a066..98e7cbc9f9 100644 --- a/graphrag/index/operations/extract_graph/extract_graph.py +++ b/graphrag/index/operations/extract_graph/extract_graph.py @@ -35,56 +35,7 @@ async def extract_graph( entity_types=DEFAULT_ENTITY_TYPES, num_threads: int = 4, ) -> tuple[pd.DataFrame, pd.DataFrame]: - """ - Extract entities from a piece of text. - - ## Usage - ```yaml - args: - column: the_document_text_column_to_extract_graph_from - id_column: the_column_with_the_unique_id_for_each_row - to: the_column_to_output_the_entities_to - strategy: , see strategies section below - summarize_descriptions: true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */ - entity_types: - - list - - of - - entity - - types - - to - - extract - ``` - - ## Strategies - The entity extract verb uses a strategy to extract entities from a document. The strategy is a json object which defines the strategy to use. The following strategies are available: - - ### graph_intelligence - This strategy uses the [graph_intelligence] library to extract entities from a document. In particular it uses a LLM to extract entities from a piece of text. The strategy config is as follows: - - ```yml - strategy: - type: graph_intelligence - extraction_prompt: !include ./extract_graph_prompt.txt # Optional, the prompt to use for extraction - completion_delimiter: "<|COMPLETE|>" # Optional, the delimiter to use for the LLM to mark completion - tuple_delimiter: "<|>" # Optional, the delimiter to use for the LLM to mark a tuple - record_delimiter: "##" # Optional, the delimiter to use for the LLM to mark a record - - encoding_name: cl100k_base # Optional, The encoding to use for the LLM with gleanings - - llm: # The configuration for the LLM - type: openai # the type of llm to use, available options are: openai, azure, openai_chat, azure_openai_chat. The last two being chat based LLMs. 
- api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai - model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai - max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai - organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai - - # if using azure flavor - api_base: !ENV ${GRAPHRAG_OPENAI_API_BASE} # The api base to use for azure - api_version: !ENV ${GRAPHRAG_OPENAI_API_VERSION} # The api version to use for azure - proxy: !ENV ${GRAPHRAG_OPENAI_PROXY} # The proxy to use for azure - - ``` - """ + """Extract a graph from a piece of text using a language model.""" log.debug("entity_extract strategy=%s", strategy) if entity_types is None: entity_types = DEFAULT_ENTITY_TYPES diff --git a/graphrag/index/operations/extract_graph/graph_extractor.py b/graphrag/index/operations/extract_graph/graph_extractor.py index 08f5b9553a..f7601f2601 100644 --- a/graphrag/index/operations/extract_graph/graph_extractor.py +++ b/graphrag/index/operations/extract_graph/graph_extractor.py @@ -11,9 +11,8 @@ from typing import Any import networkx as nx -import tiktoken -from graphrag.config.defaults import ENCODING_MODEL, graphrag_config_defaults +from graphrag.config.defaults import graphrag_config_defaults from graphrag.index.typing.error_handler import ErrorHandlerFn from graphrag.index.utils.string import clean_str from graphrag.language_model.protocol.base import ChatModel @@ -53,7 +52,6 @@ class GraphExtractor: _input_descriptions_key: str _extraction_prompt: str _summarization_prompt: str - _loop_args: dict[str, Any] _max_gleanings: int _on_error: ErrorHandlerFn @@ -67,7 +65,6 @@ def __init__( completion_delimiter_key: str | None = None, prompt: str | None = None, join_descriptions=True, - encoding_model: str | None = None, max_gleanings: int | None = None, on_error: ErrorHandlerFn | None = None, ): @@ -90,12 +87,6 @@ def __init__( ) self._on_error = on_error or (lambda _e, _s, _d: None) - # Construct the looping arguments - encoding = tiktoken.get_encoding(encoding_model or ENCODING_MODEL) - yes = f"{encoding.encode('Y')[0]}" - no = f"{encoding.encode('N')[0]}" - self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1} - async def __call__( self, texts: list[str], prompt_variables: dict[str, Any] | None = None ) -> GraphExtractionResult: @@ -160,28 +151,28 @@ async def _process_document( ) results = response.output.content or "" - # Repeat to ensure we maximize entity count - for i in range(self._max_gleanings): - response = await self._model.achat( - CONTINUE_PROMPT, - name=f"extract-continuation-{i}", - history=response.history, - ) - results += response.output.content or "" - - # if this is the final glean, don't bother updating the continuation flag - if i >= self._max_gleanings - 1: - break + # if gleanings are specified, enter a loop to extract more entities + # there are two exit criteria: (a) we hit the configured max, (b) the model says there are no more entities + if self._max_gleanings > 0: + for i in range(self._max_gleanings): + response = await self._model.achat( + CONTINUE_PROMPT, + name=f"extract-continuation-{i}", + history=response.history, + ) + results += response.output.content or "" - response = await self._model.achat( - LOOP_PROMPT, - name=f"extract-loopcheck-{i}", - history=response.history, - model_parameters=self._loop_args, - ) + # if this is the final glean, don't bother updating the continuation flag + if i >= self._max_gleanings - 1: + break - if 
response.output.content != "Y": - break + response = await self._model.achat( + LOOP_PROMPT, + name=f"extract-loopcheck-{i}", + history=response.history, + ) + if response.output.content != "Y": + break return results diff --git a/graphrag/index/operations/extract_graph/graph_intelligence_strategy.py b/graphrag/index/operations/extract_graph/graph_intelligence_strategy.py index 6632e4c736..9bb6a88db6 100644 --- a/graphrag/index/operations/extract_graph/graph_intelligence_strategy.py +++ b/graphrag/index/operations/extract_graph/graph_intelligence_strategy.py @@ -53,7 +53,6 @@ async def run_extract_graph( record_delimiter = args.get("record_delimiter", None) completion_delimiter = args.get("completion_delimiter", None) extraction_prompt = args.get("extraction_prompt", None) - encoding_model = args.get("encoding_name", None) max_gleanings = args.get( "max_gleanings", graphrag_config_defaults.extract_graph.max_gleanings ) @@ -61,7 +60,6 @@ async def run_extract_graph( extractor = GraphExtractor( model_invoker=model, prompt=extraction_prompt, - encoding_model=encoding_model, max_gleanings=max_gleanings, on_error=lambda e, s, d: ( callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None diff --git a/graphrag/index/operations/summarize_communities/build_mixed_context.py b/graphrag/index/operations/summarize_communities/build_mixed_context.py index 846e9629e1..6c1893ee54 100644 --- a/graphrag/index/operations/summarize_communities/build_mixed_context.py +++ b/graphrag/index/operations/summarize_communities/build_mixed_context.py @@ -11,7 +11,7 @@ from graphrag.query.llm.text_utils import num_tokens -def build_mixed_context(context: list[dict], max_tokens: int) -> str: +def build_mixed_context(context: list[dict], max_context_tokens: int) -> str: """ Build parent context by concatenating all sub-communities' contexts. 
@@ -47,7 +47,7 @@ def build_mixed_context(context: list[dict], max_tokens: int) -> str: local_context=remaining_local_context + final_local_contexts, sub_community_reports=substitute_reports, ) - if num_tokens(new_context_string) <= max_tokens: + if num_tokens(new_context_string) <= max_context_tokens: exceeded_limit = False context_string = new_context_string break @@ -63,7 +63,7 @@ def build_mixed_context(context: list[dict], max_tokens: int) -> str: new_context_string = pd.DataFrame(substitute_reports).to_csv( index=False, sep="," ) - if num_tokens(new_context_string) > max_tokens: + if num_tokens(new_context_string) > max_context_tokens: break context_string = new_context_string diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor.py b/graphrag/index/operations/summarize_communities/community_reports_extractor.py index d7dabb2468..73ac0ec9e2 100644 --- a/graphrag/index/operations/summarize_communities/community_reports_extractor.py +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor.py @@ -6,7 +6,6 @@ import logging import traceback from dataclasses import dataclass -from typing import Any from pydantic import BaseModel, Field @@ -16,6 +15,10 @@ log = logging.getLogger(__name__) +# these keys are used in the prompt template +INPUT_TEXT_KEY = "input_text" +MAX_LENGTH_KEY = "max_report_length" class FindingModel(BaseModel): """A model for the expected LLM response shape.""" @@ -48,7 +51,6 @@ class CommunityReportsExtractor: """Community reports extractor class definition.""" _model: ChatModel - _input_text_key: str _extraction_prompt: str _output_formatter_prompt: str _on_error: ErrorHandlerFn @@ -57,32 +59,29 @@ class CommunityReportsExtractor: def __init__( self, model_invoker: ChatModel, - input_text_key: str | None = None, extraction_prompt: str | None = None, on_error: ErrorHandlerFn | None = None, max_report_length: int | None = None, ): """Init method definition.""" self._model = model_invoker - self._input_text_key = input_text_key or "input_text" self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT self._on_error = on_error or (lambda _e, _s, _d: None) self._max_report_length = max_report_length or 1500 - async def __call__(self, inputs: dict[str, Any]): + async def __call__(self, input_text: str): """Call method definition.""" output = None try: - input_text = inputs[self._input_text_key] - prompt = self._extraction_prompt.replace( - "{" + self._input_text_key + "}", input_text - ) + prompt = self._extraction_prompt.format(**{ + INPUT_TEXT_KEY: input_text, + MAX_LENGTH_KEY: str(self._max_report_length), + }) response = await self._model.achat( prompt, json=True, # Leaving this as True to avoid creating new cache entries name="create_community_report", json_model=CommunityReportResponse, # A model is required when using json mode - model_parameters={"max_tokens": self._max_report_length}, ) output = response.parsed_response diff --git a/graphrag/index/operations/summarize_communities/graph_context/context_builder.py b/graphrag/index/operations/summarize_communities/graph_context/context_builder.py index e171ce40b6..8c33fe8269 100.644 --- a/graphrag/index/operations/summarize_communities/graph_context/context_builder.py +++ b/graphrag/index/operations/summarize_communities/graph_context/context_builder.py @@ -40,7 +40,7 @@ def build_local_context( edges, claims, callbacks: WorkflowCallbacks, - max_tokens: int = 16_000, + max_context_tokens: int = 16_000, ): """Prep communities for report generation.""" 
levels = get_levels(nodes, schemas.COMMUNITY_LEVEL) @@ -49,7 +49,7 @@ def build_local_context( for level in progress_iterable(levels, callbacks.progress, len(levels)): communities_at_level_df = _prepare_reports_at_level( - nodes, edges, claims, level, max_tokens + nodes, edges, claims, level, max_context_tokens ) communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level @@ -64,7 +64,7 @@ def _prepare_reports_at_level( edge_df: pd.DataFrame, claim_df: pd.DataFrame | None, level: int, - max_tokens: int = 16_000, + max_context_tokens: int = 16_000, ) -> pd.DataFrame: """Prepare reports at a given level.""" # Filter and prepare node details @@ -181,7 +181,7 @@ def _prepare_reports_at_level( # Generate community-level context strings using vectorized batch processing return parallel_sort_context_batch( community_df, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, ) @@ -190,7 +190,7 @@ def build_level_context( community_hierarchy_df: pd.DataFrame, local_context_df: pd.DataFrame, level: int, - max_tokens: int, + max_context_tokens: int, ) -> pd.DataFrame: """ Prep context for each community in a given level. @@ -219,7 +219,7 @@ def build_level_context( if report_df is None or report_df.empty: invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context( - invalid_context_df, max_tokens + invalid_context_df, max_context_tokens ) invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[ :, schemas.CONTEXT_STRING ] @@ -233,14 +233,18 @@ # first get local context and report (if available) for each sub-community sub_context_df = _get_subcontext_df(level + 1, report_df, local_context_df) community_df = _get_community_df( - level, invalid_context_df, sub_context_df, community_hierarchy_df, max_tokens + level, + invalid_context_df, + sub_context_df, + community_hierarchy_df, + max_context_tokens, ) # handle any remaining invalid records that can't be substituted with sub-community reports # this should be rare, but if it happens, we will just trim the local context to fit the limit remaining_df = _antijoin_reports(invalid_context_df, community_df) remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context( - remaining_df, max_tokens + remaining_df, max_context_tokens ) result = union(valid_context_df, community_df, remaining_df) @@ -265,17 +269,19 @@ def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame: return antijoin(df, reports, schemas.COMMUNITY_ID) -def _sort_and_trim_context(df: pd.DataFrame, max_tokens: int) -> pd.Series: +def _sort_and_trim_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series: """Sort and trim context to fit the limit.""" series = cast("pd.Series", df[schemas.ALL_CONTEXT]) - return transform_series(series, lambda x: sort_context(x, max_tokens=max_tokens)) + return transform_series( + series, lambda x: sort_context(x, max_context_tokens=max_context_tokens) + ) -def _build_mixed_context(df: pd.DataFrame, max_tokens: int) -> pd.Series: +def _build_mixed_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series: """Sort and trim context to fit the limit.""" series = cast("pd.Series", df[schemas.ALL_CONTEXT]) return transform_series( - series, lambda x: build_mixed_context(x, max_tokens=max_tokens) + series, lambda x: build_mixed_context(x, max_context_tokens=max_context_tokens) ) @@ -297,7 +303,7 @@ def _get_community_df( invalid_context_df: pd.DataFrame, sub_context_df: pd.DataFrame, community_hierarchy_df: pd.DataFrame, - max_tokens: int, + max_context_tokens: 
int, ) -> pd.DataFrame: """Get community context for each community.""" # collect all sub communities' contexts for each community @@ -332,7 +338,7 @@ def _get_community_df( .reset_index() ) community_df[schemas.CONTEXT_STRING] = _build_mixed_context( - community_df, max_tokens + community_df, max_context_tokens ) community_df[schemas.COMMUNITY_LEVEL] = level return community_df diff --git a/graphrag/index/operations/summarize_communities/graph_context/sort_context.py b/graphrag/index/operations/summarize_communities/graph_context/sort_context.py index 20d84aaa2c..e822ad313b 100644 --- a/graphrag/index/operations/summarize_communities/graph_context/sort_context.py +++ b/graphrag/index/operations/summarize_communities/graph_context/sort_context.py @@ -11,7 +11,7 @@ def sort_context( local_context: list[dict], sub_community_reports: list[dict] | None = None, - max_tokens: int | None = None, + max_context_tokens: int | None = None, node_name_column: str = schemas.TITLE, node_details_column: str = schemas.NODE_DETAILS, edge_id_column: str = schemas.SHORT_ID, @@ -112,7 +112,7 @@ def _get_context_string( new_context_string = _get_context_string( sorted_nodes, sorted_edges, sorted_claims, sub_community_reports ) - if max_tokens and num_tokens(new_context_string) > max_tokens: + if max_context_tokens and num_tokens(new_context_string) > max_context_tokens: break context_string = new_context_string @@ -122,7 +122,7 @@ def _get_context_string( ) -def parallel_sort_context_batch(community_df, max_tokens, parallel=False): +def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False): """Calculate context using parallelization if enabled.""" if parallel: # Use ThreadPoolExecutor for parallel execution @@ -131,7 +131,7 @@ def parallel_sort_context_batch(community_df, max_tokens, parallel=False): with ThreadPoolExecutor(max_workers=None) as executor: context_strings = list( executor.map( - lambda x: sort_context(x, max_tokens=max_tokens), + lambda x: sort_context(x, max_context_tokens=max_context_tokens), community_df[schemas.ALL_CONTEXT], ) ) @@ -140,7 +140,9 @@ def parallel_sort_context_batch(community_df, max_tokens, parallel=False): else: # Assign context strings directly to the DataFrame community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply( - lambda context_list: sort_context(context_list, max_tokens=max_tokens) + lambda context_list: sort_context( + context_list, max_context_tokens=max_context_tokens + ) ) # Calculate other columns @@ -148,7 +150,7 @@ def parallel_sort_context_batch(community_df, max_tokens, parallel=False): num_tokens ) community_df[schemas.CONTEXT_EXCEED_FLAG] = ( - community_df[schemas.CONTEXT_SIZE] > max_tokens + community_df[schemas.CONTEXT_SIZE] > max_context_tokens ) return community_df diff --git a/graphrag/index/operations/summarize_communities/strategies.py b/graphrag/index/operations/summarize_communities/strategies.py index 430771f542..4a42fbf9d1 100644 --- a/graphrag/index/operations/summarize_communities/strategies.py +++ b/graphrag/index/operations/summarize_communities/strategies.py @@ -66,7 +66,7 @@ async def _run_extractor( try: await rate_limiter.acquire() - results = await extractor({"input_text": input}) + results = await extractor(input) report = results.structured_output if report is None: log.warning("No report found for community: %s", community) diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py index 
276a2143a3..0afdbe9a73 100644 --- a/graphrag/index/operations/summarize_communities/summarize_communities.py +++ b/graphrag/index/operations/summarize_communities/summarize_communities.py @@ -64,7 +64,7 @@ async def summarize_communities( community_hierarchy_df=community_hierarchy, local_context_df=local_contexts, level=level, - max_tokens=max_input_length, + max_context_tokens=max_input_length, ) level_contexts.append(level_context) diff --git a/graphrag/index/operations/summarize_communities/text_unit_context/context_builder.py b/graphrag/index/operations/summarize_communities/text_unit_context/context_builder.py index 54aa72bfaa..95f3621858 100644 --- a/graphrag/index/operations/summarize_communities/text_unit_context/context_builder.py +++ b/graphrag/index/operations/summarize_communities/text_unit_context/context_builder.py @@ -27,7 +27,7 @@ def build_local_context( community_membership_df: pd.DataFrame, text_units_df: pd.DataFrame, node_df: pd.DataFrame, - max_tokens: int = 16000, + max_context_tokens: int = 16000, ) -> pd.DataFrame: """ Prep context data for community report generation using text unit data. @@ -75,7 +75,7 @@ def build_local_context( lambda x: num_tokens(x) ) context_df[schemas.CONTEXT_EXCEED_FLAG] = context_df[schemas.CONTEXT_SIZE].apply( - lambda x: x > max_tokens + lambda x: x > max_context_tokens ) return context_df @@ -86,7 +86,7 @@ def build_level_context( community_hierarchy_df: pd.DataFrame, local_context_df: pd.DataFrame, level: int, - max_tokens: int = 16000, + max_context_tokens: int = 16000, ) -> pd.DataFrame: """ Prep context for each community in a given level. @@ -116,7 +116,7 @@ def build_level_context( invalid_context_df.loc[:, [schemas.CONTEXT_STRING]] = invalid_context_df[ schemas.ALL_CONTEXT - ].apply(lambda x: sort_context(x, max_tokens=max_tokens)) + ].apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens)) invalid_context_df.loc[:, [schemas.CONTEXT_SIZE]] = invalid_context_df[ schemas.CONTEXT_STRING ].apply(lambda x: num_tokens(x)) @@ -199,7 +199,7 @@ def build_level_context( .reset_index() ) community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply( - lambda x: build_mixed_context(x, max_tokens) + lambda x: build_mixed_context(x, max_context_tokens) ) community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply( lambda x: num_tokens(x) @@ -220,7 +220,7 @@ def build_level_context( ) remaining_df[schemas.CONTEXT_STRING] = cast( "pd.DataFrame", remaining_df[schemas.ALL_CONTEXT] - ).apply(lambda x: sort_context(x, max_tokens=max_tokens)) + ).apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens)) remaining_df[schemas.CONTEXT_SIZE] = cast( "pd.DataFrame", remaining_df[schemas.CONTEXT_STRING] ).apply(lambda x: num_tokens(x)) diff --git a/graphrag/index/operations/summarize_communities/text_unit_context/sort_context.py b/graphrag/index/operations/summarize_communities/text_unit_context/sort_context.py index 57e43b8caf..2435dfbdb6 100644 --- a/graphrag/index/operations/summarize_communities/text_unit_context/sort_context.py +++ b/graphrag/index/operations/summarize_communities/text_unit_context/sort_context.py @@ -58,7 +58,7 @@ def get_context_string( def sort_context( local_context: list[dict], sub_community_reports: list[dict] | None = None, - max_tokens: int | None = None, + max_context_tokens: int | None = None, ) -> str: """Sort local context (list of text units) by total degree of associated nodes in descending order.""" sorted_text_units = sorted( @@ -69,11 +69,11 @@ def 
sort_context( context_string = "" for record in sorted_text_units: current_text_units.append(record) - if max_tokens: + if max_context_tokens: new_context_string = get_context_string( current_text_units, sub_community_reports ) - if num_tokens(new_context_string) > max_tokens: + if num_tokens(new_context_string) > max_context_tokens: break context_string = new_context_string diff --git a/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py b/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py index ba11983ea6..d037cbb318 100644 --- a/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py +++ b/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py @@ -11,10 +11,10 @@ from graphrag.language_model.protocol.base import ChatModel from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT -# Max token size for input prompts -DEFAULT_MAX_INPUT_TOKENS = 4_000 -# Max token count for LLM answers -DEFAULT_MAX_SUMMARY_LENGTH = 500 +# these keys are used in the prompt template +ENTITY_NAME_KEY = "entity_name" +DESCRIPTION_LIST_KEY = "description_list" +MAX_LENGTH_KEY = "max_length" @dataclass @@ -29,8 +29,6 @@ class SummarizeExtractor: """Unipartite graph extractor class definition.""" _model: ChatModel - _entity_name_key: str - _input_descriptions_key: str _summarization_prompt: str _on_error: ErrorHandlerFn _max_summary_length: int @@ -39,23 +37,19 @@ class SummarizeExtractor: def __init__( self, model_invoker: ChatModel, - entity_name_key: str | None = None, - input_descriptions_key: str | None = None, + max_summary_length: int, + max_input_tokens: int, summarization_prompt: str | None = None, on_error: ErrorHandlerFn | None = None, - max_summary_length: int | None = None, - max_input_tokens: int | None = None, ): """Init method definition.""" # TODO: streamline construction self._model = model_invoker - self._entity_name_key = entity_name_key or "entity_name" - self._input_descriptions_key = input_descriptions_key or "description_list" self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT self._on_error = on_error or (lambda _e, _s, _d: None) - self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH - self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS + self._max_summary_length = max_summary_length + self._max_input_tokens = max_input_tokens async def __call__( self, @@ -127,13 +121,13 @@ async def _summarize_descriptions_with_llm( """Summarize descriptions using the LLM.""" response = await self._model.achat( self._summarization_prompt.format(**{ - self._entity_name_key: json.dumps(id, ensure_ascii=False), - self._input_descriptions_key: json.dumps( + ENTITY_NAME_KEY: json.dumps(id, ensure_ascii=False), + DESCRIPTION_LIST_KEY: json.dumps( sorted(descriptions), ensure_ascii=False ), + MAX_LENGTH_KEY: self._max_summary_length, }), name="summarize", - model_parameters={"max_tokens": self._max_summary_length}, ) # Calculate result return str(response.output.content) diff --git a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py index e16a9dc22f..2b95d6b1e5 100644 --- a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py +++ b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py @@ -47,22 +47,18 @@ async def run_summarize_descriptions( """Run the description summarization 
chain.""" # Extraction Arguments summarize_prompt = args.get("summarize_prompt", None) - entity_name_key = args.get("entity_name_key", "entity_name") - input_descriptions_key = args.get("input_descriptions_key", "description_list") - max_tokens = args.get("max_tokens", None) - + max_input_tokens = args["max_input_tokens"] + max_summary_length = args["max_summary_length"] extractor = SummarizeExtractor( model_invoker=model, summarization_prompt=summarize_prompt, - entity_name_key=entity_name_key, - input_descriptions_key=input_descriptions_key, on_error=lambda e, stack, details: ( callbacks.error("Entity Extraction Error", e, stack, details) if callbacks else None ), - max_summary_length=args.get("max_summary_length", None), - max_input_tokens=max_tokens, + max_summary_length=max_summary_length, + max_input_tokens=max_input_tokens, ) result = await extractor(id=id, descriptions=descriptions) diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py index 25331b9071..86ffb6dd6e 100644 --- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py +++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py @@ -28,47 +28,7 @@ async def summarize_descriptions( strategy: dict[str, Any] | None = None, num_threads: int = 4, ) -> tuple[pd.DataFrame, pd.DataFrame]: - """ - Summarize entity and relationship descriptions from an entity graph. - - ## Usage - - To turn this feature ON please set the environment variable `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_ENABLED=True`. - - ### yaml - - ```yaml - args: - strategy: , see strategies section below - ``` - - ## Strategies - - The summarize descriptions verb uses a strategy to summarize descriptions for entities. The strategy is a json object which defines the strategy to use. The following strategies are available: - - ### graph_intelligence - - This strategy uses the [graph_intelligence] library to summarize descriptions for entities. The strategy config is as follows: - - ```yml - strategy: - type: graph_intelligence - summarize_prompt: # Optional, the prompt to use for extraction - - - llm: # The configuration for the LLM - type: openai # the type of llm to use, available options are: openai, azure, openai_chat, azure_openai_chat. The last two being chat based LLMs. 
- api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai - model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai - max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai - organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai - - # if using azure flavor - api_base: !ENV ${GRAPHRAG_OPENAI_API_BASE} # The api base to use for azure - api_version: !ENV ${GRAPHRAG_OPENAI_API_VERSION} # The api version to use for azure - proxy: !ENV ${GRAPHRAG_OPENAI_PROXY} # The proxy to use for azure - ``` - """ + """Summarize entity and relationship descriptions from an entity graph, using a language model.""" log.debug("summarize_descriptions strategy=%s", strategy) strategy = strategy or {} strategy_exec = load_strategy( diff --git a/graphrag/language_model/protocol/base.py b/graphrag/language_model/protocol/base.py index fc2a0a98c3..74cd38746e 100644 --- a/graphrag/language_model/protocol/base.py +++ b/graphrag/language_model/protocol/base.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator + from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.language_model.response.base import ModelResponse @@ -20,6 +21,9 @@ class EmbeddingModel(Protocol): This protocol defines the methods required for an embedding-based LM. """ + config: LanguageModelConfig + """Passthrough of the config used to create the model instance.""" + async def aembed_batch( self, text_list: list[str], **kwargs: Any ) -> list[list[float]]: @@ -87,6 +91,9 @@ class ChatModel(Protocol): Prompt is always required for the chat method, and any other keyword arguments are forwarded to the Model provider. """ + config: LanguageModelConfig + """Passthrough of the config used to create the model instance.""" + async def achat( self, prompt: str, history: list | None = None, **kwargs: Any ) -> ModelResponse: diff --git a/graphrag/language_model/providers/fnllm/models.py b/graphrag/language_model/providers/fnllm/models.py index 27c04e5e94..fda91c96ba 100644 --- a/graphrag/language_model/providers/fnllm/models.py +++ b/graphrag/language_model/providers/fnllm/models.py @@ -62,6 +62,7 @@ def __init__( cache=model_cache, events=FNLLMEvents(error_handler) if error_handler else None, ) + self.config = config async def achat( self, prompt: str, history: list | None = None, **kwargs @@ -167,6 +168,7 @@ def __init__( cache=model_cache, events=FNLLMEvents(error_handler) if error_handler else None, ) + self.config = config async def aembed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]: """ @@ -258,6 +260,7 @@ def __init__( cache=model_cache, events=FNLLMEvents(error_handler) if error_handler else None, ) + self.config = config async def achat( self, prompt: str, history: list | None = None, **kwargs @@ -365,6 +368,7 @@ def __init__( cache=model_cache, events=FNLLMEvents(error_handler) if error_handler else None, ) + self.config = config async def aembed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]: """ diff --git a/graphrag/language_model/providers/fnllm/utils.py b/graphrag/language_model/providers/fnllm/utils.py index a493089160..f50b0250e2 100644 --- a/graphrag/language_model/providers/fnllm/utils.py +++ b/graphrag/language_model/providers/fnllm/utils.py @@ -54,12 +54,7 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon JsonStrategy.VALID if config.model_supports_json else JsonStrategy.LOOSE ) 
chat_parameters = OpenAIChatParameters( - frequency_penalty=config.frequency_penalty, - presence_penalty=config.presence_penalty, - top_p=config.top_p, - max_tokens=config.max_tokens, - n=config.n, - temperature=config.temperature, + **get_openai_model_parameters_from_config(config) ) if azure: @@ -130,3 +125,36 @@ def run_coroutine_sync(coroutine: Coroutine[Any, Any, T]) -> T: _thr.start() future = asyncio.run_coroutine_threadsafe(coroutine, _loop) return future.result() + + +def is_reasoning_model(model: str) -> bool: + """Return whether the model is a known OpenAI reasoning model.""" + return model.lower() in {"o1", "o1-mini", "o3-mini"} + + +def get_openai_model_parameters_from_config( + config: LanguageModelConfig, +) -> dict[str, Any]: + """Get the model parameters for a given config, adjusting for reasoning API differences.""" + return get_openai_model_parameters_from_dict(config.model_dump()) + + +def get_openai_model_parameters_from_dict(config: dict[str, Any]) -> dict[str, Any]: + """Get the model parameters for a given config, adjusting for reasoning API differences.""" + params = { + "n": config.get("n"), + } + if is_reasoning_model(config["model"]): + params["max_completion_tokens"] = config.get("max_completion_tokens") + params["reasoning_effort"] = config.get("reasoning_effort") + else: + params["max_tokens"] = config.get("max_tokens") + params["temperature"] = config.get("temperature") + params["frequency_penalty"] = config.get("frequency_penalty") + params["presence_penalty"] = config.get("presence_penalty") + params["top_p"] = config.get("top_p") + + if config.get("response_format"): + params["response_format"] = config["response_format"] + + return params diff --git a/graphrag/prompts/index/community_report.py b/graphrag/prompts/index/community_report.py index 35ca38bc8b..c3a7702ba0 100644 --- a/graphrag/prompts/index/community_report.py +++ b/graphrag/prompts/index/community_report.py @@ -51,6 +51,7 @@ Do not include information where the supporting evidence for it is not provided. +Limit the total report length to {max_report_length} words. # Example Input ----------- @@ -147,4 +148,6 @@ Do not include information where the supporting evidence for it is not provided. +Limit the total report length to {max_report_length} words. + Output:""" diff --git a/graphrag/prompts/index/community_report_text_units.py b/graphrag/prompts/index/community_report_text_units.py index 966bab61b4..47fcd29c09 100644 --- a/graphrag/prompts/index/community_report_text_units.py +++ b/graphrag/prompts/index/community_report_text_units.py @@ -45,6 +45,8 @@ where 1, 2, 4, 5, 7, 23, 2, 34, and 46 represent the id (not the index) of the relevant data record. +Limit the total report length to {max_report_length} words. + # Example Input ----------- SOURCES diff --git a/graphrag/prompts/index/extract_claims.py b/graphrag/prompts/index/extract_claims.py index f784c02d07..5e0e5570c6 100644 --- a/graphrag/prompts/index/extract_claims.py +++ b/graphrag/prompts/index/extract_claims.py @@ -58,4 +58,4 @@ CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n" -LOOP_PROMPT = "It appears some entities may have still been missed. Answer Y or N if there are still entities that need to be added.\n" +LOOP_PROMPT = "It appears some entities may have still been missed. Answer Y if there are still entities that need to be added, or N if there are none. 
Please answer with a single letter Y or N.\n" diff --git a/graphrag/prompts/index/extract_graph.py b/graphrag/prompts/index/extract_graph.py index b1aaea3d3f..a94b36142e 100644 --- a/graphrag/prompts/index/extract_graph.py +++ b/graphrag/prompts/index/extract_graph.py @@ -126,4 +126,4 @@ Output:""" CONTINUE_PROMPT = "MANY entities and relationships were missed in the last extraction. Remember to ONLY emit entities that match any of the previously extracted types. Add them below using the same format:\n" -LOOP_PROMPT = "It appears some entities and relationships may have still been missed. Answer Y or N if there are still entities or relationships that need to be added.\n" +LOOP_PROMPT = "It appears some entities and relationships may have still been missed. Answer Y if there are still entities or relationships that need to be added, or N if there are none. Please answer with a single letter Y or N.\n" diff --git a/graphrag/prompts/index/summarize_descriptions.py b/graphrag/prompts/index/summarize_descriptions.py index 8e544999ad..4a916195bf 100644 --- a/graphrag/prompts/index/summarize_descriptions.py +++ b/graphrag/prompts/index/summarize_descriptions.py @@ -5,10 +5,11 @@ SUMMARIZE_PROMPT = """ You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. -Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. +Given one or more entities, and a list of descriptions, all related to the same entity or group of entities. Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. Make sure it is written in third person, and include the entity names so we have the full context. +Limit the final description length to {max_length} words. ####### -Data- diff --git a/graphrag/prompts/query/global_search_map_system_prompt.py b/graphrag/prompts/query/global_search_map_system_prompt.py index db1a649df3..02e98f9daa 100644 --- a/graphrag/prompts/query/global_search_map_system_prompt.py +++ b/graphrag/prompts/query/global_search_map_system_prompt.py @@ -42,6 +42,7 @@ Do not include information where the supporting evidence for it is not provided. +Limit your response length to {max_length} words. ---Data tables--- @@ -72,6 +73,8 @@ Do not include information where the supporting evidence for it is not provided. +Limit your response length to {max_length} words. + The response should be JSON formatted as follows: {{ "points": [ diff --git a/graphrag/prompts/query/global_search_reduce_system_prompt.py b/graphrag/prompts/query/global_search_reduce_system_prompt.py index c9dbb9188d..01bf455237 100644 --- a/graphrag/prompts/query/global_search_reduce_system_prompt.py +++ b/graphrag/prompts/query/global_search_reduce_system_prompt.py @@ -35,6 +35,7 @@ Do not include information where the supporting evidence for it is not provided. +Limit your response length to {max_length} words. ---Target response length and format--- @@ -70,6 +71,7 @@ Do not include information where the supporting evidence for it is not provided. +Limit your response length to {max_length} words. 
---Target response length and format--- diff --git a/graphrag/query/context_builder/community_context.py b/graphrag/query/context_builder/community_context.py index 88afe20ae9..ba506a0a9a 100644 --- a/graphrag/query/context_builder/community_context.py +++ b/graphrag/query/context_builder/community_context.py @@ -34,7 +34,7 @@ def build_community_context( include_community_weight: bool = True, community_weight_name: str = "occurrence weight", normalize_community_weight: bool = True, - max_tokens: int = 8000, + max_context_tokens: int = 8000, single_batch: bool = True, context_name: str = "Reports", random_state: int = 86, @@ -154,7 +154,7 @@ def _cut_batch() -> None: new_context_text, new_context = _report_context_text(report, attributes) new_tokens = num_tokens(new_context_text, token_encoder) - if batch_tokens + new_tokens > max_tokens: + if batch_tokens + new_tokens > max_context_tokens: # add the current batch to the context data and start a new batch if we are in multi-batch mode _cut_batch() if single_batch: diff --git a/graphrag/query/context_builder/conversation_history.py b/graphrag/query/context_builder/conversation_history.py index 33f516dbd4..3039db29d4 100644 --- a/graphrag/query/context_builder/conversation_history.py +++ b/graphrag/query/context_builder/conversation_history.py @@ -151,7 +151,7 @@ def build_context( token_encoder: tiktoken.Encoding | None = None, include_user_turns_only: bool = True, max_qa_turns: int | None = 5, - max_tokens: int = 8000, + max_context_tokens: int = 8000, recency_bias: bool = True, column_delimiter: str = "|", context_name: str = "Conversation History", @@ -202,7 +202,7 @@ def build_context( context_df = pd.DataFrame(turn_list) context_text = header + context_df.to_csv(sep=column_delimiter, index=False) - if num_tokens(context_text, token_encoder) > max_tokens: + if num_tokens(context_text, token_encoder) > max_context_tokens: break current_context_df = context_df diff --git a/graphrag/query/context_builder/dynamic_community_selection.py b/graphrag/query/context_builder/dynamic_community_selection.py index 932d271ce6..80fc562ae1 100644 --- a/graphrag/query/context_builder/dynamic_community_selection.py +++ b/graphrag/query/context_builder/dynamic_community_selection.py @@ -20,8 +20,6 @@ log = logging.getLogger(__name__) -DEFAULT_RATE_LLM_PARAMS = {"temperature": 0.0, "max_tokens": 2000} - class DynamicCommunitySelection: """Dynamic community selection to select community reports that are relevant to the query. 
@@ -42,7 +40,7 @@ def __init__( num_repeats: int = 1, max_level: int = 2, concurrent_coroutines: int = 8, - llm_kwargs: Any = DEFAULT_RATE_LLM_PARAMS, + model_params: dict[str, Any] | None = None, ): self.model = model self.token_encoder = token_encoder @@ -53,7 +51,7 @@ def __init__( self.keep_parent = keep_parent self.max_level = max_level self.semaphore = asyncio.Semaphore(concurrent_coroutines) - self.llm_kwargs = llm_kwargs + self.model_params = model_params if model_params else {} self.reports = {report.community_id: report for report in community_reports} self.communities = {community.short_id: community for community in communities} @@ -103,7 +101,7 @@ async def select(self, query: str) -> tuple[list[CommunityReport], dict[str, Any rate_query=self.rate_query, num_repeats=self.num_repeats, semaphore=self.semaphore, - **self.llm_kwargs, + **self.model_params, ) for community in queue ]) diff --git a/graphrag/query/context_builder/local_context.py b/graphrag/query/context_builder/local_context.py index fca6259f0d..dcbda89a4e 100644 --- a/graphrag/query/context_builder/local_context.py +++ b/graphrag/query/context_builder/local_context.py @@ -30,7 +30,7 @@ def build_entity_context( selected_entities: list[Entity], token_encoder: tiktoken.Encoding | None = None, - max_tokens: int = 8000, + max_context_tokens: int = 8000, include_entity_rank: bool = True, rank_description: str = "number of relationships", column_delimiter: str = "|", @@ -72,7 +72,7 @@ def build_entity_context( new_context.append(field_value) new_context_text = column_delimiter.join(new_context) + "\n" new_tokens = num_tokens(new_context_text, token_encoder) - if current_tokens + new_tokens > max_tokens: + if current_tokens + new_tokens > max_context_tokens: break current_context_text += new_context_text all_context_records.append(new_context) @@ -92,7 +92,7 @@ def build_covariates_context( selected_entities: list[Entity], covariates: list[Covariate], token_encoder: tiktoken.Encoding | None = None, - max_tokens: int = 8000, + max_context_tokens: int = 8000, column_delimiter: str = "|", context_name: str = "Covariates", ) -> tuple[str, pd.DataFrame]: @@ -136,7 +136,7 @@ def build_covariates_context( new_context_text = column_delimiter.join(new_context) + "\n" new_tokens = num_tokens(new_context_text, token_encoder) - if current_tokens + new_tokens > max_tokens: + if current_tokens + new_tokens > max_context_tokens: break current_context_text += new_context_text all_context_records.append(new_context) @@ -157,7 +157,7 @@ def build_relationship_context( relationships: list[Relationship], token_encoder: tiktoken.Encoding | None = None, include_relationship_weight: bool = False, - max_tokens: int = 8000, + max_context_tokens: int = 8000, top_k_relationships: int = 10, relationship_ranking_attribute: str = "rank", column_delimiter: str = "|", @@ -209,7 +209,7 @@ def build_relationship_context( new_context.append(field_value) new_context_text = column_delimiter.join(new_context) + "\n" new_tokens = num_tokens(new_context_text, token_encoder) - if current_tokens + new_tokens > max_tokens: + if current_tokens + new_tokens > max_context_tokens: break current_context_text += new_context_text all_context_records.append(new_context) diff --git a/graphrag/query/context_builder/rate_relevancy.py b/graphrag/query/context_builder/rate_relevancy.py index f2357212f7..b9d494f2a4 100644 --- a/graphrag/query/context_builder/rate_relevancy.py +++ b/graphrag/query/context_builder/rate_relevancy.py @@ -26,7 +26,7 @@ async def rate_relevancy( 
rate_query: str = RATE_QUERY, num_repeats: int = 1, semaphore: asyncio.Semaphore | None = None, - **llm_kwargs: Any, + **model_params: Any, ) -> dict[str, Any]: """ Rate the relevancy between the query and description on a scale of 0 to 10. @@ -38,7 +38,7 @@ async def rate_relevancy( llm: LLM model to use for rating token_encoder: token encoder num_repeats: number of times to repeat the rating process for the same community (default: 1) - llm_kwargs: additional arguments to pass to the LLM model + model_params: additional arguments to pass to the LLM model semaphore: asyncio.Semaphore to limit the number of concurrent LLM calls (default: None) """ llm_calls, prompt_tokens, output_tokens, ratings = 0, 0, 0, [] @@ -51,7 +51,7 @@ async def rate_relevancy( for _ in range(num_repeats): async with semaphore if semaphore is not None else nullcontext(): model_response = await model.achat( - prompt=query, history=messages, model_parameters=llm_kwargs, json=True + prompt=query, history=messages, model_parameters=model_params, json=True ) response = model_response.output.content try: diff --git a/graphrag/query/context_builder/source_context.py b/graphrag/query/context_builder/source_context.py index 0fb140bd86..b29ee9c0e5 100644 --- a/graphrag/query/context_builder/source_context.py +++ b/graphrag/query/context_builder/source_context.py @@ -23,7 +23,7 @@ def build_text_unit_context( token_encoder: tiktoken.Encoding | None = None, column_delimiter: str = "|", shuffle_data: bool = True, - max_tokens: int = 8000, + max_context_tokens: int = 8000, context_name: str = "Sources", random_state: int = 86, ) -> tuple[str, dict[str, pd.DataFrame]]: @@ -62,7 +62,7 @@ def build_text_unit_context( new_context_text = column_delimiter.join(new_context) + "\n" new_tokens = num_tokens(new_context_text, token_encoder) - if current_tokens + new_tokens > max_tokens: + if current_tokens + new_tokens > max_context_tokens: break current_context_text += new_context_text diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py index decc3f0c3d..76fa1f430f 100644 --- a/graphrag/query/factory.py +++ b/graphrag/query/factory.py @@ -14,6 +14,9 @@ from graphrag.data_model.relationship import Relationship from graphrag.data_model.text_unit import TextUnit from graphrag.language_model.manager import ModelManager +from graphrag.language_model.providers.fnllm.utils import ( + get_openai_model_parameters_from_config, +) from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey from graphrag.query.structured_search.basic_search.basic_context import ( BasicSearchContext, @@ -77,6 +80,8 @@ def get_local_search_engine( ls_config = config.local_search + model_params = get_openai_model_parameters_from_config(model_settings) + return LocalSearch( model=chat_model, system_prompt=system_prompt, @@ -92,12 +97,7 @@ def get_local_search_engine( token_encoder=token_encoder, ), token_encoder=token_encoder, - model_params={ - "max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500) - "temperature": ls_config.temperature, - "top_p": ls_config.top_p, - "n": ls_config.n, - }, + model_params=model_params, context_builder_params={ "text_unit_prop": ls_config.text_unit_prop, "community_prop": ls_config.community_prop, @@ -110,7 +110,7 @@ def get_local_search_engine( "include_community_rank": False, "return_candidate_context": False, "embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to 
EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids - "max_tokens": ls_config.max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + "max_context_tokens": ls_config.max_context_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) }, response_type=response_type, callbacks=callbacks, @@ -130,7 +130,6 @@ def get_global_search_engine( callbacks: list[QueryCallbacks] | None = None, ) -> GlobalSearch: """Create a global search engine based on data + configuration.""" - # TODO: Global search should select model based on config?? model_settings = config.get_language_model_config( config.global_search.chat_model_id ) @@ -143,6 +142,8 @@ def get_global_search_engine( config=model_settings, ) + model_params = get_openai_model_parameters_from_config(model_settings) + # Here we get encoding based on specified encoding name token_encoder = tiktoken.get_encoding(model_settings.encoding_model) gs_config = config.global_search @@ -153,14 +154,14 @@ def get_global_search_engine( dynamic_community_selection_kwargs.update({ "model": model, - # And here we get encoding based on model - "token_encoder": tiktoken.encoding_for_model(model_settings.model), + "token_encoder": token_encoder, "keep_parent": gs_config.dynamic_search_keep_parent, "num_repeats": gs_config.dynamic_search_num_repeats, "use_summary": gs_config.dynamic_search_use_summary, - "concurrent_coroutines": gs_config.dynamic_search_concurrent_coroutines, + "concurrent_coroutines": model_settings.concurrent_requests, "threshold": gs_config.dynamic_search_threshold, "max_level": gs_config.dynamic_search_max_level, + "model_params": {**model_params}, }) return GlobalSearch( @@ -178,18 +179,10 @@ def get_global_search_engine( ), token_encoder=token_encoder, max_data_tokens=gs_config.data_max_tokens, - map_llm_params={ - "max_tokens": gs_config.map_max_tokens, - "temperature": gs_config.temperature, - "top_p": gs_config.top_p, - "n": gs_config.n, - }, - reduce_llm_params={ - "max_tokens": gs_config.reduce_max_tokens, - "temperature": gs_config.temperature, - "top_p": gs_config.top_p, - "n": gs_config.n, - }, + map_llm_params={**model_params}, + reduce_llm_params={**model_params}, + map_max_length=gs_config.map_max_length, + reduce_max_length=gs_config.reduce_max_length, allow_general_knowledge=False, json_mode=False, context_builder_params={ @@ -201,10 +194,10 @@ def get_global_search_engine( "include_community_weight": True, "community_weight_name": "occurrence weight", "normalize_community_weight": True, - "max_tokens": gs_config.max_tokens, + "max_context_tokens": gs_config.max_context_tokens, "context_name": "Reports", }, - concurrent_coroutines=gs_config.concurrency, + concurrent_coroutines=model_settings.concurrent_requests, response_type=response_type, callbacks=callbacks, ) @@ -243,6 +236,7 @@ def get_drift_search_engine( embedding_model_settings = config.get_language_model_config( config.drift_search.embedding_model_id ) + if embedding_model_settings.max_retries == -1: embedding_model_settings.max_retries = ( len(reports) + len(entities) + len(relationships) @@ -253,6 +247,7 @@ def get_drift_search_engine( model_type=embedding_model_settings.type, config=embedding_model_settings, ) + token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model) return DRIFTSearch( @@ -310,7 +305,9 @@ def get_basic_search_engine( token_encoder = 
tiktoken.get_encoding(chat_model_settings.encoding_model) - ls_config = config.basic_search + bs_config = config.basic_search + + model_params = get_openai_model_parameters_from_config(chat_model_settings) return BasicSearch( model=chat_model, @@ -322,19 +319,10 @@ def get_basic_search_engine( token_encoder=token_encoder, ), token_encoder=token_encoder, - model_params={ - "max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500) - "temperature": ls_config.temperature, - "top_p": ls_config.top_p, - "n": ls_config.n, - }, + model_params=model_params, context_builder_params={ - "text_unit_prop": ls_config.text_unit_prop, - "conversation_history_max_turns": ls_config.conversation_history_max_turns, - "conversation_history_user_turns_only": True, - "return_candidate_context": False, "embedding_vectorstore_key": "id", - "max_tokens": ls_config.max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + "k": bs_config.k, }, callbacks=callbacks, ) diff --git a/graphrag/query/structured_search/basic_search/search.py b/graphrag/query/structured_search/basic_search/search.py index 242c4856b9..e2fb29c012 100644 --- a/graphrag/query/structured_search/basic_search/search.py +++ b/graphrag/query/structured_search/basic_search/search.py @@ -20,11 +20,6 @@ from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import BaseSearch, SearchResult -DEFAULT_LLM_PARAMS = { - "max_tokens": 1500, - "temperature": 0.0, -} - log = logging.getLogger(__name__) """ Implementation of a generic RAG algorithm (vector search on raw text chunks) @@ -42,7 +37,7 @@ def __init__( system_prompt: str | None = None, response_type: str = "multiple paragraphs", callbacks: list[QueryCallbacks] | None = None, - model_params: dict[str, Any] = DEFAULT_LLM_PARAMS, + model_params: dict[str, Any] | None = None, context_builder_params: dict | None = None, ): super().__init__( diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py index 50a4d02050..66d86be88b 100644 --- a/graphrag/query/structured_search/drift_search/primer.py +++ b/graphrag/query/structured_search/drift_search/primer.py @@ -137,7 +137,6 @@ async def decompose_query( prompt = DRIFT_PRIMER_PROMPT.format( query=query, community_reports=community_reports ) - model_response = await self.chat_model.achat(prompt, json=True) response = model_response.output.content diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py index 14e12120cb..2099e3592e 100644 --- a/graphrag/query/structured_search/drift_search/search.py +++ b/graphrag/query/structured_search/drift_search/search.py @@ -13,6 +13,9 @@ from graphrag.callbacks.query_callbacks import QueryCallbacks from graphrag.language_model.protocol.base import ChatModel +from graphrag.language_model.providers.fnllm.utils import ( + get_openai_model_parameters_from_dict, +) from graphrag.query.context_builder.conversation_history import ConversationHistory from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey from graphrag.query.llm.text_utils import num_tokens @@ -80,14 +83,18 @@ def init_local_search(self) -> LocalSearch: "include_community_rank": False, "return_candidate_context": False, "embedding_vectorstore_key": 
EntityVectorStoreKey.ID, - "max_tokens": self.context_builder.config.local_search_max_data_tokens, + "max_context_tokens": self.context_builder.config.local_search_max_data_tokens, } - model_params = { + model_params = get_openai_model_parameters_from_dict({ + "model": self.model.config.model, "max_tokens": self.context_builder.config.local_search_llm_max_gen_tokens, "temperature": self.context_builder.config.local_search_temperature, + "n": self.context_builder.config.local_search_n, + "top_p": self.context_builder.config.local_search_top_p, + "max_completion_tokens": self.context_builder.config.local_search_llm_max_gen_completion_tokens, "response_format": {"type": "json_object"}, - } + }) return LocalSearch( model=self.model, @@ -262,14 +269,20 @@ async def search( for callback in self.callbacks: callback.on_reduce_response_start(response_state) + model_params = get_openai_model_parameters_from_dict({ + "model": self.model.config.model, + "max_tokens": self.context_builder.config.reduce_max_tokens, + "temperature": self.context_builder.config.reduce_temperature, + "max_completion_tokens": self.context_builder.config.reduce_max_completion_tokens, + }) + reduced_response = await self._reduce_response( responses=response_state, query=query, llm_calls=llm_calls, prompt_tokens=prompt_tokens, output_tokens=output_tokens, - max_tokens=self.context_builder.config.reduce_max_tokens, - temperature=self.context_builder.config.reduce_temperature, + model_params=model_params, ) for callback in self.callbacks: @@ -307,12 +320,18 @@ async def stream_search( for callback in self.callbacks: callback.on_reduce_response_start(result.response) + model_params = get_openai_model_parameters_from_dict({ + "model": self.model.config.model, + "max_tokens": self.context_builder.config.reduce_max_tokens, + "temperature": self.context_builder.config.reduce_temperature, + "max_completion_tokens": self.context_builder.config.reduce_max_completion_tokens, + }) + full_response = "" async for resp in self._reduce_response_streaming( responses=result.response, query=query, - max_tokens=self.context_builder.config.reduce_max_tokens, - temperature=self.context_builder.config.reduce_temperature, + model_params=model_params, ): full_response += resp yield resp @@ -384,7 +403,7 @@ async def _reduce_response_streaming( self, responses: str | dict[str, Any], query: str, - **llm_kwargs, + model_params: dict[str, Any], ) -> AsyncGenerator[str, None]: """Reduce the response to a single comprehensive response. @@ -394,8 +413,6 @@ async def _reduce_response_streaming( The responses to reduce. query : str The original query. - llm_kwargs : dict[str, Any] - Additional keyword arguments to pass to the LLM. 
Returns ------- @@ -424,7 +441,7 @@ async def _reduce_response_streaming( async for response in self.model.achat_stream( prompt=query, history=search_messages, - model_parameters=llm_kwargs, + model_parameters=model_params, ): for callback in self.callbacks: callback.on_llm_new_token(response) diff --git a/graphrag/query/structured_search/global_search/community_context.py b/graphrag/query/structured_search/global_search/community_context.py index 35a2d12a6c..56fa0b42b3 100644 --- a/graphrag/query/structured_search/global_search/community_context.py +++ b/graphrag/query/structured_search/global_search/community_context.py @@ -65,7 +65,7 @@ async def build_context( include_community_weight: bool = True, community_weight_name: str = "occurrence", normalize_community_weight: bool = True, - max_tokens: int = 8000, + max_context_tokens: int = 8000, context_name: str = "Reports", conversation_history_user_turns_only: bool = True, conversation_history_max_turns: int | None = 5, @@ -84,7 +84,7 @@ async def build_context( include_user_turns_only=conversation_history_user_turns_only, max_qa_turns=conversation_history_max_turns, column_delimiter=column_delimiter, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, recency_bias=False, ) if conversation_history_context != "": @@ -113,7 +113,7 @@ async def build_context( include_community_weight=include_community_weight, community_weight_name=community_weight_name, normalize_community_weight=normalize_community_weight, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, single_batch=False, context_name=context_name, random_state=self.random_state, diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index f2e82af899..b7f75a43ee 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -33,16 +33,6 @@ from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object from graphrag.query.structured_search.base import BaseSearch, SearchResult -DEFAULT_MAP_LLM_PARAMS = { - "max_tokens": 1000, - "temperature": 0.0, -} - -DEFAULT_REDUCE_LLM_PARAMS = { - "max_tokens": 2000, - "temperature": 0.0, -} - log = logging.getLogger(__name__) @@ -71,8 +61,10 @@ def __init__( json_mode: bool = True, callbacks: list[QueryCallbacks] | None = None, max_data_tokens: int = 8000, - map_llm_params: dict[str, Any] = DEFAULT_MAP_LLM_PARAMS, - reduce_llm_params: dict[str, Any] = DEFAULT_REDUCE_LLM_PARAMS, + map_llm_params: dict[str, Any] | None = None, + reduce_llm_params: dict[str, Any] | None = None, + map_max_length: int = 1000, + reduce_max_length: int = 2000, context_builder_params: dict[str, Any] | None = None, concurrent_coroutines: int = 32, ): @@ -92,13 +84,15 @@ def __init__( self.callbacks = callbacks or [] self.max_data_tokens = max_data_tokens - self.map_llm_params = map_llm_params - self.reduce_llm_params = reduce_llm_params + self.map_llm_params = map_llm_params if map_llm_params else {} + self.reduce_llm_params = reduce_llm_params if reduce_llm_params else {} if json_mode: self.map_llm_params["response_format"] = {"type": "json_object"} else: # remove response_format key if json_mode is False self.map_llm_params.pop("response_format", None) + self.map_max_length = map_max_length + self.reduce_max_length = reduce_max_length self.semaphore = asyncio.Semaphore(concurrent_coroutines) @@ -118,7 +112,10 @@ async def stream_search( map_responses = await asyncio.gather(*[ 
self._map_response_single_batch( - context_data=data, query=query, **self.map_llm_params + context_data=data, + query=query, + max_length=self.map_max_length, + **self.map_llm_params, ) for data in context_result.context_chunks ]) @@ -130,6 +127,7 @@ async def stream_search( async for response in self._stream_reduce_response( map_responses=map_responses, # type: ignore query=query, + max_length=self.reduce_max_length, model_parameters=self.reduce_llm_params, ): yield response @@ -166,7 +164,10 @@ async def search( map_responses = await asyncio.gather(*[ self._map_response_single_batch( - context_data=data, query=query, **self.map_llm_params + context_data=data, + query=query, + max_length=self.map_max_length, + **self.map_llm_params, ) for data in context_result.context_chunks ]) @@ -209,13 +210,16 @@ async def _map_response_single_batch( self, context_data: str, query: str, + max_length: int, **llm_kwargs, ) -> SearchResult: """Generate answer for a single chunk of community reports.""" start_time = time.time() search_prompt = "" try: - search_prompt = self.map_system_prompt.format(context_data=context_data) + search_prompt = self.map_system_prompt.format( + context_data=context_data, max_length=max_length + ) search_messages = [ {"role": "system", "content": search_prompt}, ] @@ -411,6 +415,7 @@ async def _stream_reduce_response( self, map_responses: list[SearchResult], query: str, + max_length: int, **llm_kwargs, ) -> AsyncGenerator[str, None]: # collect all key points into a single list to prepare for sorting @@ -469,7 +474,9 @@ async def _stream_reduce_response( text_data = "\n\n".join(data) search_prompt = self.reduce_system_prompt.format( - report_data=text_data, response_type=self.response_type + report_data=text_data, + response_type=self.response_type, + max_length=max_length, ) if self.allow_general_knowledge: search_prompt += "\n" + self.general_knowledge_inclusion_prompt diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index b5b4e5f5b9..8883d009e7 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -96,7 +96,7 @@ def build_context( exclude_entity_names: list[str] | None = None, conversation_history_max_turns: int | None = 5, conversation_history_user_turns_only: bool = True, - max_tokens: int = 8000, + max_context_tokens: int = 8000, text_unit_prop: float = 0.5, community_prop: float = 0.25, top_k_mapped_entities: int = 10, @@ -161,21 +161,21 @@ def build_context( include_user_turns_only=conversation_history_user_turns_only, max_qa_turns=conversation_history_max_turns, column_delimiter=column_delimiter, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, recency_bias=False, ) if conversation_history_context.strip() != "": final_context.append(conversation_history_context) final_context_data = conversation_history_context_data - max_tokens = max_tokens - num_tokens( + max_context_tokens = max_context_tokens - num_tokens( conversation_history_context, self.token_encoder ) # build community context - community_tokens = max(int(max_tokens * community_prop), 0) + community_tokens = max(int(max_context_tokens * community_prop), 0) community_context, community_context_data = self._build_community_context( selected_entities=selected_entities, - max_tokens=community_tokens, + max_context_tokens=community_tokens, use_community_summary=use_community_summary, 
column_delimiter=column_delimiter, include_community_rank=include_community_rank, @@ -189,10 +189,10 @@ def build_context( # build local (i.e. entity-relationship-covariate) context local_prop = 1 - community_prop - text_unit_prop - local_tokens = max(int(max_tokens * local_prop), 0) + local_tokens = max(int(max_context_tokens * local_prop), 0) local_context, local_context_data = self._build_local_context( selected_entities=selected_entities, - max_tokens=local_tokens, + max_context_tokens=local_tokens, include_entity_rank=include_entity_rank, rank_description=rank_description, include_relationship_weight=include_relationship_weight, @@ -205,10 +205,10 @@ def build_context( final_context.append(str(local_context)) final_context_data = {**final_context_data, **local_context_data} - text_unit_tokens = max(int(max_tokens * text_unit_prop), 0) + text_unit_tokens = max(int(max_context_tokens * text_unit_prop), 0) text_unit_context, text_unit_context_data = self._build_text_unit_context( selected_entities=selected_entities, - max_tokens=text_unit_tokens, + max_context_tokens=text_unit_tokens, return_candidate_context=return_candidate_context, ) @@ -224,7 +224,7 @@ def build_context( def _build_community_context( self, selected_entities: list[Entity], - max_tokens: int = 4000, + max_context_tokens: int = 4000, use_community_summary: bool = False, column_delimiter: str = "|", include_community_rank: bool = False, @@ -232,7 +232,7 @@ def _build_community_context( return_candidate_context: bool = False, context_name: str = "Reports", ) -> tuple[str, dict[str, pd.DataFrame]]: - """Add community data to the context window until it hits the max_tokens limit.""" + """Add community data to the context window until it hits the max_context_tokens limit.""" if len(selected_entities) == 0 or len(self.community_reports) == 0: return ("", {context_name.lower(): pd.DataFrame()}) @@ -270,7 +270,7 @@ def _build_community_context( shuffle_data=False, include_community_rank=include_community_rank, min_community_rank=min_community_rank, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, single_batch=True, context_name=context_name, ) @@ -306,12 +306,12 @@ def _build_community_context( def _build_text_unit_context( self, selected_entities: list[Entity], - max_tokens: int = 8000, + max_context_tokens: int = 8000, return_candidate_context: bool = False, column_delimiter: str = "|", context_name: str = "Sources", ) -> tuple[str, dict[str, pd.DataFrame]]: - """Rank matching text units and add them to the context window until it hits the max_tokens limit.""" + """Rank matching text units and add them to the context window until it hits the max_context_tokens limit.""" if not selected_entities or not self.text_units: return ("", {context_name.lower(): pd.DataFrame()}) selected_text_units = [] @@ -345,7 +345,7 @@ def _build_text_unit_context( context_text, context_data = build_text_unit_context( text_units=selected_text_units, token_encoder=self.token_encoder, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, shuffle_data=False, context_name=context_name, column_delimiter=column_delimiter, @@ -377,7 +377,7 @@ def _build_text_unit_context( def _build_local_context( self, selected_entities: list[Entity], - max_tokens: int = 8000, + max_context_tokens: int = 8000, include_entity_rank: bool = False, rank_description: str = "relationship count", include_relationship_weight: bool = False, @@ -391,7 +391,7 @@ def _build_local_context( entity_context, entity_context_data = build_entity_context( 
selected_entities=selected_entities, token_encoder=self.token_encoder, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, column_delimiter=column_delimiter, include_entity_rank=include_entity_rank, rank_description=rank_description, @@ -418,7 +418,7 @@ def _build_local_context( selected_entities=added_entities, relationships=list(self.relationships.values()), token_encoder=self.token_encoder, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, column_delimiter=column_delimiter, top_k_relationships=top_k_relationships, include_relationship_weight=include_relationship_weight, @@ -437,7 +437,7 @@ def _build_local_context( selected_entities=added_entities, covariates=self.covariates[covariate], token_encoder=self.token_encoder, - max_tokens=max_tokens, + max_context_tokens=max_context_tokens, column_delimiter=column_delimiter, context_name=covariate, ) @@ -445,7 +445,7 @@ def _build_local_context( current_context.append(covariate_context) current_context_data[covariate.lower()] = covariate_context_data - if total_tokens > max_tokens: + if total_tokens > max_context_tokens: log.info("Reached token limit - reverting to previous context state") break diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py index ed55eb2876..3a02caaf44 100644 --- a/graphrag/query/structured_search/local_search/search.py +++ b/graphrag/query/structured_search/local_search/search.py @@ -22,11 +22,6 @@ from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import BaseSearch, SearchResult -DEFAULT_LLM_PARAMS = { - "max_tokens": 1500, - "temperature": 0.0, -} - log = logging.getLogger(__name__) @@ -41,7 +36,7 @@ def __init__( system_prompt: str | None = None, response_type: str = "multiple paragraphs", callbacks: list[QueryCallbacks] | None = None, - model_params: dict[str, Any] = DEFAULT_LLM_PARAMS, + model_params: dict[str, Any] | None = None, context_builder_params: dict | None = None, ): super().__init__( diff --git a/poetry.lock b/poetry.lock index 397330cb52..02b404a932 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. 
[[package]] name = "aiofiles" @@ -3987,13 +3987,13 @@ scipy = ">=1.0" [[package]] name = "pyparsing" -version = "3.2.1" +version = "3.2.3" description = "pyparsing module - Classes and methods to define and execute parsing grammars" optional = false python-versions = ">=3.9" files = [ - {file = "pyparsing-3.2.1-py3-none-any.whl", hash = "sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1"}, - {file = "pyparsing-3.2.1.tar.gz", hash = "sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a"}, + {file = "pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf"}, + {file = "pyparsing-3.2.3.tar.gz", hash = "sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be"}, ] [package.extras] @@ -4104,13 +4104,13 @@ six = ">=1.5" [[package]] name = "python-dotenv" -version = "1.0.1" +version = "1.1.0" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, - {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, + {file = "python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d"}, + {file = "python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5"}, ] [package.extras] @@ -4132,13 +4132,13 @@ dev = ["backports.zoneinfo", "black", "build", "freezegun", "mdx_truly_sane_list [[package]] name = "pytz" -version = "2025.1" +version = "2025.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" files = [ - {file = "pytz-2025.1-py2.py3-none-any.whl", hash = "sha256:89dd22dca55b46eac6eda23b2d72721bf1bdfef212645d81513ef5d03038de57"}, - {file = "pytz-2025.1.tar.gz", hash = "sha256:c2db42be2a2518b28e65f9207c4d05e6ff547d1efa4086469ef855e4ab70178e"}, + {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, + {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, ] [[package]] @@ -4852,13 +4852,13 @@ win32 = ["pywin32"] [[package]] name = "setuptools" -version = "77.0.3" +version = "78.0.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.9" files = [ - {file = "setuptools-77.0.3-py3-none-any.whl", hash = "sha256:67122e78221da5cf550ddd04cf8742c8fe12094483749a792d56cd669d6cf58c"}, - {file = "setuptools-77.0.3.tar.gz", hash = "sha256:583b361c8da8de57403743e756609670de6fb2345920e36dc5c2d914c319c945"}, + {file = "setuptools-78.0.2-py3-none-any.whl", hash = "sha256:4a612c80e1f1d71b80e4906ce730152e8dec23df439f82731d9d0b608d7b700d"}, + {file = "setuptools-78.0.2.tar.gz", hash = "sha256:137525e6afb9022f019d6e884a319017f9bf879a0d8783985d32cbc8683cab93"}, ] [package.extras] @@ -5289,42 +5289,42 @@ files = [ [[package]] name = "tiktoken" -version = "0.8.0" +version = "0.9.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" optional = false python-versions = ">=3.9" files = [ - {file = "tiktoken-0.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:b07e33283463089c81ef1467180e3e00ab00d46c2c4bbcef0acab5f771d6695e"}, - {file = "tiktoken-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9269348cb650726f44dd3bbb3f9110ac19a8dcc8f54949ad3ef652ca22a38e21"}, - {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e13f37bc4ef2d012731e93e0fef21dc3b7aea5bb9009618de9a4026844e560"}, - {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f13d13c981511331eac0d01a59b5df7c0d4060a8be1e378672822213da51e0a2"}, - {file = "tiktoken-0.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6b2ddbc79a22621ce8b1166afa9f9a888a664a579350dc7c09346a3b5de837d9"}, - {file = "tiktoken-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d8c2d0e5ba6453a290b86cd65fc51fedf247e1ba170191715b049dac1f628005"}, - {file = "tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1"}, - {file = "tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a"}, - {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d"}, - {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47"}, - {file = "tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419"}, - {file = "tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99"}, - {file = "tiktoken-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:881839cfeae051b3628d9823b2e56b5cc93a9e2efb435f4cf15f17dc45f21586"}, - {file = "tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fe9399bdc3f29d428f16a2f86c3c8ec20be3eac5f53693ce4980371c3245729b"}, - {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a58deb7075d5b69237a3ff4bb51a726670419db6ea62bdcd8bd80c78497d7ab"}, - {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2908c0d043a7d03ebd80347266b0e58440bdef5564f84f4d29fb235b5df3b04"}, - {file = "tiktoken-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:294440d21a2a51e12d4238e68a5972095534fe9878be57d905c476017bff99fc"}, - {file = "tiktoken-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d8f3192733ac4d77977432947d563d7e1b310b96497acd3c196c9bddb36ed9db"}, - {file = "tiktoken-0.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:02be1666096aff7da6cbd7cdaa8e7917bfed3467cd64b38b1f112e96d3b06a24"}, - {file = "tiktoken-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94ff53c5c74b535b2cbf431d907fc13c678bbd009ee633a2aca269a04389f9a"}, - {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b231f5e8982c245ee3065cd84a4712d64692348bc609d84467c57b4b72dcbc5"}, - {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4177faa809bd55f699e88c96d9bb4635d22e3f59d635ba6fd9ffedf7150b9953"}, - {file = "tiktoken-0.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5376b6f8dc4753cd81ead935c5f518fa0fbe7e133d9e25f648d8c4dabdd4bad7"}, - {file = "tiktoken-0.8.0-cp313-cp313-win_amd64.whl", hash = 
"sha256:18228d624807d66c87acd8f25fc135665617cab220671eb65b50f5d70fa51f69"}, - {file = "tiktoken-0.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7e17807445f0cf1f25771c9d86496bd8b5c376f7419912519699f3cc4dc5c12e"}, - {file = "tiktoken-0.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:886f80bd339578bbdba6ed6d0567a0d5c6cfe198d9e587ba6c447654c65b8edc"}, - {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6adc8323016d7758d6de7313527f755b0fc6c72985b7d9291be5d96d73ecd1e1"}, - {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b591fb2b30d6a72121a80be24ec7a0e9eb51c5500ddc7e4c2496516dd5e3816b"}, - {file = "tiktoken-0.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:845287b9798e476b4d762c3ebda5102be87ca26e5d2c9854002825d60cdb815d"}, - {file = "tiktoken-0.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:1473cfe584252dc3fa62adceb5b1c763c1874e04511b197da4e6de51d6ce5a02"}, - {file = "tiktoken-0.8.0.tar.gz", hash = "sha256:9ccbb2740f24542534369c5635cfd9b2b3c2490754a78ac8831d99f89f94eeb2"}, + {file = "tiktoken-0.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:586c16358138b96ea804c034b8acf3f5d3f0258bd2bc3b0227af4af5d622e382"}, + {file = "tiktoken-0.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d9c59ccc528c6c5dd51820b3474402f69d9a9e1d656226848ad68a8d5b2e5108"}, + {file = "tiktoken-0.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0968d5beeafbca2a72c595e8385a1a1f8af58feaebb02b227229b69ca5357fd"}, + {file = "tiktoken-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92a5fb085a6a3b7350b8fc838baf493317ca0e17bd95e8642f95fc69ecfed1de"}, + {file = "tiktoken-0.9.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:15a2752dea63d93b0332fb0ddb05dd909371ededa145fe6a3242f46724fa7990"}, + {file = "tiktoken-0.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:26113fec3bd7a352e4b33dbaf1bd8948de2507e30bd95a44e2b1156647bc01b4"}, + {file = "tiktoken-0.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f32cc56168eac4851109e9b5d327637f15fd662aa30dd79f964b7c39fbadd26e"}, + {file = "tiktoken-0.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45556bc41241e5294063508caf901bf92ba52d8ef9222023f83d2483a3055348"}, + {file = "tiktoken-0.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03935988a91d6d3216e2ec7c645afbb3d870b37bcb67ada1943ec48678e7ee33"}, + {file = "tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b3d80aad8d2c6b9238fc1a5524542087c52b860b10cbf952429ffb714bc1136"}, + {file = "tiktoken-0.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b2a21133be05dc116b1d0372af051cd2c6aa1d2188250c9b553f9fa49301b336"}, + {file = "tiktoken-0.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:11a20e67fdf58b0e2dea7b8654a288e481bb4fc0289d3ad21291f8d0849915fb"}, + {file = "tiktoken-0.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e88f121c1c22b726649ce67c089b90ddda8b9662545a8aeb03cfef15967ddd03"}, + {file = "tiktoken-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a6600660f2f72369acb13a57fb3e212434ed38b045fd8cc6cdd74947b4b5d210"}, + {file = "tiktoken-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95e811743b5dfa74f4b227927ed86cbc57cad4df859cb3b643be797914e41794"}, + {file = "tiktoken-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:99376e1370d59bcf6935c933cb9ba64adc29033b7e73f5f7569f3aad86552b22"}, + {file = "tiktoken-0.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:badb947c32739fb6ddde173e14885fb3de4d32ab9d8c591cbd013c22b4c31dd2"}, + {file = "tiktoken-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a62d7a25225bafed786a524c1b9f0910a1128f4232615bf3f8257a73aaa3b16"}, + {file = "tiktoken-0.9.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2b0e8e05a26eda1249e824156d537015480af7ae222ccb798e5234ae0285dbdb"}, + {file = "tiktoken-0.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:27d457f096f87685195eea0165a1807fae87b97b2161fe8c9b1df5bd74ca6f63"}, + {file = "tiktoken-0.9.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cf8ded49cddf825390e36dd1ad35cd49589e8161fdcb52aa25f0583e90a3e01"}, + {file = "tiktoken-0.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc156cb314119a8bb9748257a2eaebd5cc0753b6cb491d26694ed42fc7cb3139"}, + {file = "tiktoken-0.9.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:cd69372e8c9dd761f0ab873112aba55a0e3e506332dd9f7522ca466e817b1b7a"}, + {file = "tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95"}, + {file = "tiktoken-0.9.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c6386ca815e7d96ef5b4ac61e0048cd32ca5a92d5781255e13b31381d28667dc"}, + {file = "tiktoken-0.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:75f6d5db5bc2c6274b674ceab1615c1778e6416b14705827d19b40e6355f03e0"}, + {file = "tiktoken-0.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e15b16f61e6f4625a57a36496d28dd182a8a60ec20a534c5343ba3cafa156ac7"}, + {file = "tiktoken-0.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ebcec91babf21297022882344c3f7d9eed855931466c3311b1ad6b64befb3df"}, + {file = "tiktoken-0.9.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e5fd49e7799579240f03913447c0cdfa1129625ebd5ac440787afc4345990427"}, + {file = "tiktoken-0.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:26242ca9dc8b58e875ff4ca078b9a94d2f0813e6a535dcd2205df5d49d927cc7"}, + {file = "tiktoken-0.9.0.tar.gz", hash = "sha256:d02a5ca6a938e0490e1ff957bc48c8b078c88cb83977be1625b1fd8aac792c5d"}, ] [package.dependencies] @@ -5515,13 +5515,13 @@ files = [ [[package]] name = "tzdata" -version = "2025.1" +version = "2025.2" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" files = [ - {file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"}, - {file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"}, + {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, + {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, ] [[package]] @@ -5822,4 +5822,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2905ece63aba1f6f63ce2b18aae19d04aeacfe3758222e4b2695735831d1b180" +content-hash = "4b6e1757f36d2659776a5244d73ee9db60f42b09eb81902ae34202300e13d17e" diff --git a/pyproject.toml b/pyproject.toml index c10e35b2a5..c39b84926b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,9 +58,9 @@ aiofiles = "^24.1.0" # LLM fnllm = {extras = ["azure", "openai"], version = "0.2.3"} json-repair = "^0.30.3" 
-openai = "^1.57.0" +openai = "^1.68.0" nltk = "3.9.1" -tiktoken = "^0.8.0" +tiktoken = "^0.9.0" # Data-Science numpy = "^1.25.2" diff --git a/tests/integration/language_model/test_factory.py b/tests/integration/language_model/test_factory.py index 1eb9920ee2..e25e4e246a 100644 --- a/tests/integration/language_model/test_factory.py +++ b/tests/integration/language_model/test_factory.py @@ -20,6 +20,8 @@ async def test_create_custom_chat_model(): class CustomChatModel: + config: Any + def __init__(self, **kwargs): pass @@ -54,6 +56,8 @@ def chat_stream( async def test_create_custom_embedding_llm(): class CustomEmbeddingModel: + config: Any + def __init__(self, **kwargs): pass diff --git a/tests/mock_provider.py b/tests/mock_provider.py index 18b3d63343..d68fd762df 100644 --- a/tests/mock_provider.py +++ b/tests/mock_provider.py @@ -8,6 +8,7 @@ from pydantic import BaseModel +from graphrag.config.enums import ModelType from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.language_model.response.base import ( BaseModelOutput, @@ -28,6 +29,9 @@ def __init__( ): self.responses = config.responses if config and config.responses else responses self.response_index = 0 + self.config = config or LanguageModelConfig( + type=ModelType.MockChat, model="gpt-4o", api_key="mock" + ) async def achat( self, @@ -94,7 +98,9 @@ class MockEmbeddingLLM: """A mock embedding LLM provider.""" def __init__(self, **kwargs: Any): - pass + self.config = LanguageModelConfig( + type=ModelType.MockEmbedding, model="text-embedding-ada-002", api_key="mock" + ) def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]: """Generate an embedding for the input text.""" diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 6de8e97395..e079cf9512 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -73,6 +73,7 @@ def assert_language_model_configs( assert actual.encoding_model == expected.encoding_model assert actual.max_tokens == expected.max_tokens assert actual.temperature == expected.temperature + assert actual.max_completion_tokens == expected.max_completion_tokens assert actual.top_p == expected.top_p assert actual.n == expected.n assert actual.frequency_penalty == expected.frequency_penalty @@ -224,7 +225,6 @@ def assert_extract_graph_configs( assert actual.entity_types == expected.entity_types assert actual.max_gleanings == expected.max_gleanings assert actual.strategy == expected.strategy - assert actual.encoding_model == expected.encoding_model assert actual.model_id == expected.model_id @@ -291,7 +291,6 @@ def assert_extract_claims_configs( assert actual.description == expected.description assert actual.max_gleanings == expected.max_gleanings assert actual.strategy == expected.strategy - assert actual.encoding_model == expected.encoding_model assert actual.model_id == expected.model_id @@ -318,11 +317,7 @@ def assert_local_search_configs( ) assert actual.top_k_entities == expected.top_k_entities assert actual.top_k_relationships == expected.top_k_relationships - assert actual.temperature == expected.temperature - assert actual.top_p == expected.top_p - assert actual.n == expected.n - assert actual.max_tokens == expected.max_tokens - assert actual.llm_max_tokens == expected.llm_max_tokens + assert actual.max_context_tokens == expected.max_context_tokens def assert_global_search_configs( @@ -331,23 +326,14 @@ def assert_global_search_configs( assert actual.map_prompt == expected.map_prompt assert actual.reduce_prompt == 
expected.reduce_prompt assert actual.knowledge_prompt == expected.knowledge_prompt - assert actual.temperature == expected.temperature - assert actual.top_p == expected.top_p - assert actual.n == expected.n - assert actual.max_tokens == expected.max_tokens + assert actual.max_context_tokens == expected.max_context_tokens assert actual.data_max_tokens == expected.data_max_tokens - assert actual.map_max_tokens == expected.map_max_tokens - assert actual.reduce_max_tokens == expected.reduce_max_tokens - assert actual.concurrency == expected.concurrency - assert actual.dynamic_search_llm == expected.dynamic_search_llm + assert actual.map_max_length == expected.map_max_length + assert actual.reduce_max_length == expected.reduce_max_length assert actual.dynamic_search_threshold == expected.dynamic_search_threshold assert actual.dynamic_search_keep_parent == expected.dynamic_search_keep_parent assert actual.dynamic_search_num_repeats == expected.dynamic_search_num_repeats assert actual.dynamic_search_use_summary == expected.dynamic_search_use_summary - assert ( - actual.dynamic_search_concurrent_coroutines - == expected.dynamic_search_concurrent_coroutines - ) assert actual.dynamic_search_max_level == expected.dynamic_search_max_level @@ -356,10 +342,6 @@ def assert_drift_search_configs( ) -> None: assert actual.prompt == expected.prompt assert actual.reduce_prompt == expected.reduce_prompt - assert actual.temperature == expected.temperature - assert actual.top_p == expected.top_p - assert actual.n == expected.n - assert actual.max_tokens == expected.max_tokens assert actual.data_max_tokens == expected.data_max_tokens assert actual.reduce_max_tokens == expected.reduce_max_tokens assert actual.reduce_temperature == expected.reduce_temperature @@ -392,15 +374,7 @@ def assert_basic_search_configs( actual: BasicSearchConfig, expected: BasicSearchConfig ) -> None: assert actual.prompt == expected.prompt - assert actual.text_unit_prop == expected.text_unit_prop - assert ( - actual.conversation_history_max_turns == expected.conversation_history_max_turns - ) - assert actual.temperature == expected.temperature - assert actual.top_p == expected.top_p - assert actual.n == expected.n - assert actual.max_tokens == expected.max_tokens - assert actual.llm_max_tokens == expected.llm_max_tokens + assert actual.k == expected.k def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) -> None: diff --git a/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py b/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py index f8fc3b2f64..c5911344b6 100644 --- a/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py +++ b/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py @@ -213,7 +213,7 @@ def test_sort_context(): def test_sort_context_max_tokens(): - ctx = sort_context(context, max_tokens=800) + ctx = sort_context(context, max_context_tokens=800) assert ctx is not None, "Context is none" num = num_tokens(ctx) assert num <= 800, f"num_tokens is not less than or equal to 800: {num}" diff --git a/tests/verbs/test_extract_graph.py b/tests/verbs/test_extract_graph.py index 1336355d83..618b843078 100644 --- a/tests/verbs/test_extract_graph.py +++ b/tests/verbs/test_extract_graph.py @@ -57,6 +57,8 @@ async def test_extract_graph(): config.summarize_descriptions.strategy = { "type": "graph_intelligence", "llm": summarize_llm_settings, + "max_input_tokens": 1000, + "max_summary_length": 100, } await run_workflow(config, 
context)
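Note for reviewers: below is a minimal usage sketch of the parameter-mapping helper added in `graphrag/language_model/providers/fnllm/utils.py`. It assumes an importable checkout of this branch; the expected-value comments are illustrative, not test assertions.

```python
from graphrag.language_model.providers.fnllm.utils import (
    get_openai_model_parameters_from_dict,
)

# Standard chat model: sampling parameters pass straight through.
standard = get_openai_model_parameters_from_dict({
    "model": "gpt-4o",
    "n": 1,
    "max_tokens": 2000,
    "temperature": 0,
    "frequency_penalty": 0.0,
    "presence_penalty": 0.0,
    "top_p": 1,
})
# -> {"n": 1, "max_tokens": 2000, "temperature": 0,
#     "frequency_penalty": 0.0, "presence_penalty": 0.0, "top_p": 1}

# Reasoning model: sampling knobs are dropped and the reasoning-specific
# fields are kept instead, since the o-series endpoints reject
# temperature/top_p and take max_completion_tokens rather than max_tokens.
reasoning = get_openai_model_parameters_from_dict({
    "model": "o3-mini",
    "n": 1,
    "max_completion_tokens": 4000,
    "reasoning_effort": "high",
})
# -> {"n": 1, "max_completion_tokens": 4000, "reasoning_effort": "high"}
```

The same helper backs `get_openai_model_parameters_from_config`, so the local, global, basic, and DRIFT search factories all inherit reasoning-safe parameters from a single code path.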
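Relatedly, because the `ChatModel`/`EmbeddingModel` protocols now declare a `config` attribute (DRIFT search reads `self.model.config.model` to choose reasoning-safe parameters), custom providers must expose one. A minimal sketch of a conforming chat provider, modeled on the mock in `tests/mock_provider.py`; the stub values are placeholders:

```python
from typing import Any

from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig


class MinimalChatModel:
    """Minimal stand-in satisfying the updated ChatModel protocol."""

    def __init__(self, config: LanguageModelConfig | None = None, **kwargs: Any):
        # The protocol now requires a config passthrough; store whatever the
        # factory handed us so callers like DRIFTSearch can read config.model.
        self.config = config or LanguageModelConfig(
            type=ModelType.MockChat, model="gpt-4o", api_key="mock"
        )

    async def achat(self, prompt: str, history: list | None = None, **kwargs: Any):
        raise NotImplementedError  # a real provider returns a ModelResponse
```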