From 0d6f2bd5b2c58ba7f0cd5f80c1dfa68e221b862f Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Mon, 19 May 2025 19:38:33 -0600 Subject: [PATCH 1/6] Remove max retries. Update Typer args --- graphrag/api/prompt_tune.py | 9 - graphrag/cli/main.py | 572 ++++++++++-------- graphrag/config/init_content.py | 4 +- .../config/models/language_model_config.py | 44 ++ .../index/operations/embed_text/embed_text.py | 4 - .../extract_covariates/extract_covariates.py | 4 - .../operations/extract_graph/extract_graph.py | 4 - .../summarize_communities.py | 4 - .../summarize_descriptions.py | 4 - graphrag/index/validate_config.py | 7 +- graphrag/query/factory.py | 29 +- 11 files changed, 360 insertions(+), 325 deletions(-) diff --git a/graphrag/api/prompt_tune.py b/graphrag/api/prompt_tune.py index 92f6ba94f3..29a8172011 100644 --- a/graphrag/api/prompt_tune.py +++ b/graphrag/api/prompt_tune.py @@ -109,15 +109,6 @@ async def generate_indexing_prompts( logger.info("Retrieving language model configuration...") default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID) - # if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls - # to be made or fallback to a default value in the worst case - if default_llm_settings.max_retries < -1: - default_llm_settings.max_retries = min( - len(doc_list), language_model_defaults.max_retries - ) - msg = f"max_retries not set, using default value: {default_llm_settings.max_retries}" - logger.warning(msg) - logger.info("Creating language model...") llm = ModelManager().register_chat( name="prompt_tuning", diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py index 610427871f..4a1a02e6cb 100644 --- a/graphrag/cli/main.py +++ b/graphrag/cli/main.py @@ -80,23 +80,23 @@ def completer(incomplete: str) -> list[str]: @app.command("init") def _initialize_cli( - root: Annotated[ - Path, - typer.Option( - help="The project root directory.", - dir_okay=True, - writable=True, - resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, writable=True, match_wildcard="*" - ), + root: Path = typer.Option( + Path(), + "--root", "-r", + help="The project root directory.", + dir_okay=True, + writable=True, + resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, writable=True, match_wildcard="*" ), - ], - force: Annotated[ - bool, - typer.Option(help="Force initialization even if the project already exists."), - ] = False, -): + ), + force: bool = typer.Option( + False, + "--force", "-f", + help="Force initialization even if the project already exists.", + ), +) -> None: """Generate a default configuration file.""" from graphrag.cli.initialize import initialize_project_at @@ -105,60 +105,81 @@ def _initialize_cli( @app.command("index") def _index_cli( - config: Annotated[ - Path | None, - typer.Option( - help="The configuration to use.", exists=True, file_okay=True, readable=True + config: Path | None = typer.Option( + None, + "--config", + "-c", + help="The configuration to use.", + exists=True, + file_okay=True, + readable=True, + ), + root: Path = typer.Option( + Path(), + "--root", + "-r", + help="The project root directory.", + exists=True, + dir_okay=True, + writable=True, + resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, writable=True, match_wildcard="*" ), - ] = None, - root: Annotated[ - Path, - typer.Option( - help="The project root directory.", - exists=True, - dir_okay=True, - writable=True, - resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, writable=True, match_wildcard="*" - ), + ), + method: IndexingMethod = typer.Option( + IndexingMethod.Standard.value, + "--method", + "-m", + help="The indexing method to use.", + ), + verbose: bool = typer.Option( + False, + "--verbose", + "-v", + help="Run the indexing pipeline with verbose logging", + ), + memprofile: bool = typer.Option( + False, + "--memprofile", + help="Run the indexing pipeline with memory profiling", + ), + logger: LoggerType = typer.Option( + LoggerType.RICH.value, + "--logger", + help="The progress logger to use.", + ), + dry_run: bool = typer.Option( + False, + "--dry-run", + help=( + "Run the indexing pipeline without executing any steps " + "to inspect and validate the configuration." ), - ] = Path(), # set default to current directory - method: Annotated[ - IndexingMethod, typer.Option(help="The indexing method to use.") - ] = IndexingMethod.Standard, - verbose: Annotated[ - bool, typer.Option(help="Run the indexing pipeline with verbose logging") - ] = False, - memprofile: Annotated[ - bool, typer.Option(help="Run the indexing pipeline with memory profiling") - ] = False, - logger: Annotated[ - LoggerType, typer.Option(help="The progress logger to use.") - ] = LoggerType.RICH, - dry_run: Annotated[ - bool, - typer.Option( - help="Run the indexing pipeline without executing any steps to inspect and validate the configuration." + ), + cache: bool = typer.Option( + True, + "--cache/--no-cache", + help="Use LLM cache.", + ), + skip_validation: bool = typer.Option( + False, + "--skip-validation", + help="Skip any preflight validation. Useful when running no LLM steps.", + ), + output: Path | None = typer.Option( + None, + "--output", + "-o", + help=( + "Indexing pipeline output directory. " + "Overrides output.base_dir in the configuration file." ), - ] = False, - cache: Annotated[bool, typer.Option(help="Use LLM cache.")] = True, - skip_validation: Annotated[ - bool, - typer.Option( - help="Skip any preflight validation. Useful when running no LLM steps." - ), - ] = False, - output: Annotated[ - Path | None, - typer.Option( - help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.", - dir_okay=True, - writable=True, - resolve_path=True, - ), - ] = None, -): + dir_okay=True, + writable=True, + resolve_path=True, + ), +) -> None: """Build a knowledge graph index.""" from graphrag.cli.index import index_cli @@ -178,51 +199,65 @@ def _index_cli( @app.command("update") def _update_cli( - config: Annotated[ - Path | None, - typer.Option( - help="The configuration to use.", exists=True, file_okay=True, readable=True - ), - ] = None, - root: Annotated[ - Path, - typer.Option( - help="The project root directory.", - exists=True, - dir_okay=True, - writable=True, - resolve_path=True, - ), - ] = Path(), # set default to current directory - method: Annotated[ - IndexingMethod, typer.Option(help="The indexing method to use.") - ] = IndexingMethod.Standard, - verbose: Annotated[ - bool, typer.Option(help="Run the indexing pipeline with verbose logging") - ] = False, - memprofile: Annotated[ - bool, typer.Option(help="Run the indexing pipeline with memory profiling") - ] = False, - logger: Annotated[ - LoggerType, typer.Option(help="The progress logger to use.") - ] = LoggerType.RICH, - cache: Annotated[bool, typer.Option(help="Use LLM cache.")] = True, - skip_validation: Annotated[ - bool, - typer.Option( - help="Skip any preflight validation. Useful when running no LLM steps." + config: Path | None = typer.Option( + None, + "--config", "-c", + help="The configuration to use.", + exists=True, + file_okay=True, + readable=True, + ), + root: Path = typer.Option( + Path(), + "--root", "-r", + help="The project root directory.", + exists=True, + dir_okay=True, + writable=True, + resolve_path=True, + ), + method: IndexingMethod = typer.Option( + IndexingMethod.Standard.value, + "--method", "-m", + help="The indexing method to use.", + ), + verbose: bool = typer.Option( + False, + "--verbose", "-v", + help="Run the indexing pipeline with verbose logging.", + ), + memprofile: bool = typer.Option( + False, + "--memprofile", + help="Run the indexing pipeline with memory profiling.", + ), + logger: LoggerType = typer.Option( + LoggerType.RICH.value, + "--logger", + help="The progress logger to use.", + ), + cache: bool = typer.Option( + True, + "--cache/--no-cache", + help="Use LLM cache.", + ), + skip_validation: bool = typer.Option( + False, + "--skip-validation", + help="Skip any preflight validation. Useful when running no LLM steps.", + ), + output: Path | None = typer.Option( + None, + "--output", "-o", + help=( + "Indexing pipeline output directory. " + "Overrides output.base_dir in the configuration file." ), - ] = False, - output: Annotated[ - Path | None, - typer.Option( - help="Indexing pipeline output directory. Overrides output.base_dir in the configuration file.", - dir_okay=True, - writable=True, - resolve_path=True, - ), - ] = None, -): + dir_okay=True, + writable=True, + resolve_path=True, + ), +) -> None: """ Update an existing knowledge graph index. @@ -245,104 +280,107 @@ def _update_cli( @app.command("prompt-tune") def _prompt_tune_cli( - root: Annotated[ - Path, - typer.Option( - help="The project root directory.", - exists=True, - dir_okay=True, - writable=True, - resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, writable=True, match_wildcard="*" - ), - ), - ] = Path(), # set default to current directory - config: Annotated[ - Path | None, - typer.Option( - help="The configuration to use.", - exists=True, - file_okay=True, - readable=True, - autocompletion=path_autocomplete( - file_okay=True, dir_okay=False, match_wildcard="*" - ), - ), - ] = None, - verbose: Annotated[ - bool, typer.Option(help="Run the prompt tuning pipeline with verbose logging") - ] = False, - logger: Annotated[ - LoggerType, typer.Option(help="The progress logger to use.") - ] = LoggerType.RICH, - domain: Annotated[ - str | None, - typer.Option( - help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If not defined, a domain will be inferred from the input data." - ), - ] = None, - selection_method: Annotated[ - DocSelectionType, typer.Option(help="The text chunk selection method.") - ] = DocSelectionType.RANDOM, - n_subset_max: Annotated[ - int, - typer.Option( - help="The number of text chunks to embed when --selection-method=auto." - ), - ] = N_SUBSET_MAX, - k: Annotated[ - int, - typer.Option( - help="The maximum number of documents to select from each centroid when --selection-method=auto." - ), - ] = K, - limit: Annotated[ - int, - typer.Option( - help="The number of documents to load when --selection-method={random,top}." - ), - ] = LIMIT, - max_tokens: Annotated[ - int, typer.Option(help="The max token count for prompt generation.") - ] = MAX_TOKEN_COUNT, - min_examples_required: Annotated[ - int, - typer.Option( - help="The minimum number of examples to generate/include in the entity extraction prompt." - ), - ] = 2, - chunk_size: Annotated[ - int, - typer.Option( - help="The size of each example text chunk. Overrides chunks.size in the configuration file." - ), - ] = graphrag_config_defaults.chunks.size, - overlap: Annotated[ - int, - typer.Option( - help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file" + root: Path = typer.Option( + Path(), + "--root", "-r", + help="The project root directory.", + exists=True, + dir_okay=True, + writable=True, + resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, writable=True, match_wildcard="*" ), - ] = graphrag_config_defaults.chunks.overlap, - language: Annotated[ - str | None, - typer.Option( - help="The primary language used for inputs and outputs in graphrag prompts." + ), + config: Path | None = typer.Option( + None, + "--config", "-c", + help="The configuration to use.", + exists=True, + file_okay=True, + readable=True, + autocompletion=path_autocomplete( + file_okay=True, dir_okay=False, match_wildcard="*" ), - ] = None, - discover_entity_types: Annotated[ - bool, typer.Option(help="Discover and extract unspecified entity types.") - ] = True, - output: Annotated[ - Path, - typer.Option( - help="The directory to save prompts to, relative to the project root directory.", - dir_okay=True, - writable=True, - resolve_path=True, + ), + verbose: bool = typer.Option( + False, + "--verbose", "-v", + help="Run the prompt tuning pipeline with verbose logging.", + ), + logger: LoggerType = typer.Option( + LoggerType.RICH.value, + "--logger", + help="The progress logger to use.", + ), + domain: str | None = typer.Option( + None, + "--domain", + help=( + "The domain your input data is related to. " + "For example 'space science', 'microbiology', 'environmental news'. " + "If not defined, a domain will be inferred from the input data." ), - ] = Path("prompts"), -): + ), + selection_method: DocSelectionType = typer.Option( + DocSelectionType.RANDOM.value, + "--selection-method", + help="The text chunk selection method.", + ), + n_subset_max: int = typer.Option( + N_SUBSET_MAX, + "--n-subset-max", + help="The number of text chunks to embed when --selection-method=auto.", + ), + k: int = typer.Option( + K, + "--k", + help="The maximum number of documents to select from each centroid when --selection-method=auto.", + ), + limit: int = typer.Option( + LIMIT, + "--limit", + help="The number of documents to load when --selection-method={random,top}.", + ), + max_tokens: int = typer.Option( + MAX_TOKEN_COUNT, + "--max-tokens", + help="The max token count for prompt generation.", + ), + min_examples_required: int = typer.Option( + 2, + "--min-examples-required", + help="The minimum number of examples to generate/include in the entity extraction prompt.", + ), + chunk_size: int = typer.Option( + graphrag_config_defaults.chunks.size, + "--chunk-size", + help="The size of each example text chunk. Overrides chunks.size in the configuration file.", + ), + overlap: int = typer.Option( + graphrag_config_defaults.chunks.overlap, + "--overlap", + help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file.", + ), + language: str | None = typer.Option( + None, + "--language", + help="The primary language used for inputs and outputs in graphrag prompts.", + ), + discover_entity_types: bool = typer.Option( + True, + "--discover-entity-types/--no-discover-entity-types", + help="Discover and extract unspecified entity types.", + ), + output: Path = typer.Option( + Path("prompts"), + "--output", "-o", + help="The directory to save prompts to, relative to the project root directory.", + dir_okay=True, + writable=True, + resolve_path=True, + ), +) -> None: """Generate custom graphrag prompts with your own data (i.e. auto templating).""" import asyncio @@ -373,66 +411,78 @@ def _prompt_tune_cli( @app.command("query") def _query_cli( - method: Annotated[SearchMethod, typer.Option(help="The query algorithm to use.")], - query: Annotated[str, typer.Option(help="The query to execute.")], - config: Annotated[ - Path | None, - typer.Option( - help="The configuration to use.", - exists=True, - file_okay=True, - readable=True, - autocompletion=path_autocomplete( - file_okay=True, dir_okay=False, match_wildcard="*" - ), + method: SearchMethod = typer.Option( + ..., + "--method", "-m", + help="The query algorithm to use.", + ), + query: str = typer.Option( + ..., + "--query", "-q", + help="The query to execute.", + ), + config: Path | None = typer.Option( + None, + "--config", "-c", + help="The configuration to use.", + exists=True, + file_okay=True, + readable=True, + autocompletion=path_autocomplete( + file_okay=True, dir_okay=False, match_wildcard="*" ), - ] = None, - data: Annotated[ - Path | None, - typer.Option( - help="Indexing pipeline output directory (i.e. contains the parquet files).", - exists=True, - dir_okay=True, - readable=True, - resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, match_wildcard="*" - ), + ), + data: Path | None = typer.Option( + None, + "--data", "-d", + help="Index output directory (contains the parquet files).", + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, match_wildcard="*" ), - ] = None, - root: Annotated[ - Path, - typer.Option( - help="The project root directory.", - exists=True, - dir_okay=True, - writable=True, - resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, match_wildcard="*" - ), + ), + root: Path = typer.Option( + Path.cwd(), + "--root", "-r", + help="The project root directory.", + exists=True, + dir_okay=True, + writable=True, + resolve_path=True, + autocompletion=path_autocomplete( + file_okay=False, dir_okay=True, match_wildcard="*" ), - ] = Path(), # set default to current directory - community_level: Annotated[ - int, - typer.Option( - help="The community level in the Leiden community hierarchy from which to load community reports. Higher values represent reports from smaller communities." + ), + community_level: int = typer.Option( + 2, + "--community-level", + help=( + "Leiden hierarchy level from which to load community reports. " + "Higher values represent smaller communities." ), - ] = 2, - dynamic_community_selection: Annotated[ - bool, - typer.Option(help="Use global search with dynamic community selection."), - ] = False, - response_type: Annotated[ - str, - typer.Option( - help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report. Default: Multiple Paragraphs" + ), + dynamic_community_selection: bool = typer.Option( + False, + "--dynamic-community-selection/--no-dynamic-selection", + help="Use global search with dynamic community selection.", + ), + response_type: str = typer.Option( + "Multiple Paragraphs", + "--response-type", + help=( + "Free-form description of the desired response format " + "(e.g. 'Single Sentence', 'List of 3-7 Points', etc.)." ), - ] = "Multiple Paragraphs", - streaming: Annotated[ - bool, typer.Option(help="Print response in a streaming manner.") - ] = False, -): + ), + streaming: bool = typer.Option( + False, + "--streaming/--no-streaming", + help="Print the response in a streaming manner.", + ), +) -> None: """Query a knowledge graph index.""" from graphrag.cli.query import ( run_basic_search, diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index ba538639aa..08559ffbab 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -33,7 +33,7 @@ concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed async_mode: {language_model_defaults.async_mode.value} # or asyncio retry_strategy: native - max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response) + max_retries: {language_model_defaults.max_retries} tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting {defs.DEFAULT_EMBEDDING_MODEL_ID}: @@ -51,7 +51,7 @@ concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed async_mode: {language_model_defaults.async_mode.value} # or asyncio retry_strategy: native - max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response) + max_retries: {language_model_defaults.max_retries} tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting diff --git a/graphrag/config/models/language_model_config.py b/graphrag/config/models/language_model_config.py index 9b8ad9388e..7f2c39a50b 100644 --- a/graphrag/config/models/language_model_config.py +++ b/graphrag/config/models/language_model_config.py @@ -198,10 +198,38 @@ def _validate_deployment_name(self) -> None: description="The number of tokens per minute to use for the LLM service.", default=language_model_defaults.tokens_per_minute, ) + + def _validate_tokens_per_minute(self) -> None: + """Validate the tokens per minute. + + Raises + ------ + ValueError + If the tokens per minute is less than 0. + """ + # If the value is a number, check if it is less than 1 + if isinstance(self.tokens_per_minute, int) and self.tokens_per_minute < 1: + msg = f"Tokens per minute must be a non zero postive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}." + raise ValueError(msg) + requests_per_minute: int | Literal["auto"] | None = Field( description="The number of requests per minute to use for the LLM service.", default=language_model_defaults.requests_per_minute, ) + + def _validate_requests_per_minute(self) -> None: + """Validate the requests per minute. + + Raises + ------ + ValueError + If the requests per minute is less than 0. + """ + # If the value is a number, check if it is less than 1 + if isinstance(self.requests_per_minute, int) and self.requests_per_minute < 1: + msg = f"Requests per minute must be a non zero postive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}." + raise ValueError(msg) + retry_strategy: str = Field( description="The retry strategy to use for the LLM service.", default=language_model_defaults.retry_strategy, @@ -210,6 +238,19 @@ def _validate_deployment_name(self) -> None: description="The maximum number of retries to use for the LLM service.", default=language_model_defaults.max_retries, ) + + def _validate_max_retries(self) -> None: + """Validate the maximum retries. + + Raises + ------ + ValueError + If the maximum retries is less than 0. + """ + if self.max_retries < 1: + msg = f"Maximum retries must be greater than or equal to 1. Suggested value: {language_model_defaults.max_retries}." + raise ValueError(msg) + max_retry_wait: float = Field( description="The maximum retry wait to use for the LLM service.", default=language_model_defaults.max_retry_wait, @@ -279,6 +320,9 @@ def _validate_model(self): self._validate_type() self._validate_auth_type() self._validate_api_key() + self._validate_tokens_per_minute() + self._validate_requests_per_minute() + self._validate_max_retries() self._validate_azure_settings() self._validate_encoding_model() return self diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index 935644b025..b0a9e55410 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -109,10 +109,6 @@ async def _text_embed_with_vector_store( strategy_exec = load_strategy(strategy_type) strategy_config = {**strategy} - # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made - if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1: - strategy_config["llm"]["max_retries"] = len(input) - # Get vector-storage configuration insert_batch_size: int = ( vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE diff --git a/graphrag/index/operations/extract_covariates/extract_covariates.py b/graphrag/index/operations/extract_covariates/extract_covariates.py index 22e01d1f4f..7ae23e5bcf 100644 --- a/graphrag/index/operations/extract_covariates/extract_covariates.py +++ b/graphrag/index/operations/extract_covariates/extract_covariates.py @@ -50,10 +50,6 @@ async def extract_covariates( strategy = strategy or {} strategy_config = {**strategy} - # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made - if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1: - strategy_config["llm"]["max_retries"] = len(input) - async def run_strategy(row): text = row[column] result = await run_extract_claims( diff --git a/graphrag/index/operations/extract_graph/extract_graph.py b/graphrag/index/operations/extract_graph/extract_graph.py index 98e7cbc9f9..dabfd1c005 100644 --- a/graphrag/index/operations/extract_graph/extract_graph.py +++ b/graphrag/index/operations/extract_graph/extract_graph.py @@ -45,10 +45,6 @@ async def extract_graph( ) strategy_config = {**strategy} - # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made - if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1: - strategy_config["llm"]["max_retries"] = len(text_units) - num_started = 0 async def run_strategy(row): diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py index 0afdbe9a73..e64c3e23f4 100644 --- a/graphrag/index/operations/summarize_communities/summarize_communities.py +++ b/graphrag/index/operations/summarize_communities/summarize_communities.py @@ -45,10 +45,6 @@ async def summarize_communities( strategy_exec = load_strategy(strategy["type"]) strategy_config = {**strategy} - # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made - if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1: - strategy_config["llm"]["max_retries"] = len(nodes) - community_hierarchy = ( communities.explode("children") .rename({"children": "sub_community"}, axis=1) diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py index 86ffb6dd6e..92b54cf9a0 100644 --- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py +++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py @@ -36,10 +36,6 @@ async def summarize_descriptions( ) strategy_config = {**strategy} - # if max_retries is not set, inject a dynamically assigned value based on the maximum number of expected LLM calls to be made - if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1: - strategy_config["llm"]["max_retries"] = len(entities_df) + len(relationships_df) - async def get_summarized( nodes: pd.DataFrame, edges: pd.DataFrame, semaphore: asyncio.Semaphore ): diff --git a/graphrag/index/validate_config.py b/graphrag/index/validate_config.py index 4d1cf7e5cd..483979f58b 100644 --- a/graphrag/index/validate_config.py +++ b/graphrag/index/validate_config.py @@ -18,9 +18,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> # Validate Chat LLM configs # TODO: Replace default_chat_model with a way to select the model default_llm_settings = parameters.get_language_model_config("default_chat_model") - # if max_retries is not set, set it to the default value - if default_llm_settings.max_retries == -1: - default_llm_settings.max_retries = language_model_defaults.max_retries + llm = ModelManager().register_chat( name="test-llm", model_type=default_llm_settings.type, @@ -40,8 +38,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> embedding_llm_settings = parameters.get_language_model_config( parameters.embed_text.model_id ) - if embedding_llm_settings.max_retries == -1: - embedding_llm_settings.max_retries = language_model_defaults.max_retries + embed_llm = ModelManager().register_embedding( name="test-embed-llm", model_type=embedding_llm_settings.type, diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py index 907c83cacf..0dd1d3e1a7 100644 --- a/graphrag/query/factory.py +++ b/graphrag/query/factory.py @@ -52,11 +52,6 @@ def get_local_search_engine( """Create a local search engine based on data + configuration.""" model_settings = config.get_language_model_config(config.local_search.chat_model_id) - if model_settings.max_retries == -1: - model_settings.max_retries = ( - len(reports) + len(entities) + len(relationships) + len(covariates) - ) - chat_model = ModelManager().get_or_create_chat_model( name="local_search_chat", model_type=model_settings.type, @@ -66,10 +61,7 @@ def get_local_search_engine( embedding_settings = config.get_language_model_config( config.local_search.embedding_model_id ) - if embedding_settings.max_retries == -1: - embedding_settings.max_retries = ( - len(reports) + len(entities) + len(relationships) - ) + embedding_model = ModelManager().get_or_create_embedding_model( name="local_search_embedding", model_type=embedding_settings.type, @@ -134,8 +126,6 @@ def get_global_search_engine( config.global_search.chat_model_id ) - if model_settings.max_retries == -1: - model_settings.max_retries = len(reports) + len(entities) model = ModelManager().get_or_create_chat_model( name="global_search", model_type=model_settings.type, @@ -220,13 +210,6 @@ def get_drift_search_engine( config.drift_search.chat_model_id ) - if chat_model_settings.max_retries == -1: - chat_model_settings.max_retries = ( - config.drift_search.drift_k_followups - * config.drift_search.primer_folds - * config.drift_search.n_depth - ) - chat_model = ModelManager().get_or_create_chat_model( name="drift_search_chat", model_type=chat_model_settings.type, @@ -237,11 +220,6 @@ def get_drift_search_engine( config.drift_search.embedding_model_id ) - if embedding_model_settings.max_retries == -1: - embedding_model_settings.max_retries = ( - len(reports) + len(entities) + len(relationships) - ) - embedding_model = ModelManager().get_or_create_embedding_model( name="drift_search_embedding", model_type=embedding_model_settings.type, @@ -283,9 +261,6 @@ def get_basic_search_engine( config.basic_search.chat_model_id ) - if chat_model_settings.max_retries == -1: - chat_model_settings.max_retries = len(text_units) - chat_model = ModelManager().get_or_create_chat_model( name="basic_search_chat", model_type=chat_model_settings.type, @@ -295,8 +270,6 @@ def get_basic_search_engine( embedding_model_settings = config.get_language_model_config( config.basic_search.embedding_model_id ) - if embedding_model_settings.max_retries == -1: - embedding_model_settings.max_retries = len(text_units) embedding_model = ModelManager().get_or_create_embedding_model( name="basic_search_embedding", From ff52f4a28d0c1fb64751feb3446da86946e95e3d Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Mon, 19 May 2025 19:38:52 -0600 Subject: [PATCH 2/6] Format --- graphrag/cli/main.py | 48 +++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py index 4a1a02e6cb..a9efc67bb4 100644 --- a/graphrag/cli/main.py +++ b/graphrag/cli/main.py @@ -82,7 +82,8 @@ def completer(incomplete: str) -> list[str]: def _initialize_cli( root: Path = typer.Option( Path(), - "--root", "-r", + "--root", + "-r", help="The project root directory.", dir_okay=True, writable=True, @@ -93,7 +94,8 @@ def _initialize_cli( ), force: bool = typer.Option( False, - "--force", "-f", + "--force", + "-f", help="Force initialization even if the project already exists.", ), ) -> None: @@ -201,7 +203,8 @@ def _index_cli( def _update_cli( config: Path | None = typer.Option( None, - "--config", "-c", + "--config", + "-c", help="The configuration to use.", exists=True, file_okay=True, @@ -209,7 +212,8 @@ def _update_cli( ), root: Path = typer.Option( Path(), - "--root", "-r", + "--root", + "-r", help="The project root directory.", exists=True, dir_okay=True, @@ -218,12 +222,14 @@ def _update_cli( ), method: IndexingMethod = typer.Option( IndexingMethod.Standard.value, - "--method", "-m", + "--method", + "-m", help="The indexing method to use.", ), verbose: bool = typer.Option( False, - "--verbose", "-v", + "--verbose", + "-v", help="Run the indexing pipeline with verbose logging.", ), memprofile: bool = typer.Option( @@ -248,7 +254,8 @@ def _update_cli( ), output: Path | None = typer.Option( None, - "--output", "-o", + "--output", + "-o", help=( "Indexing pipeline output directory. " "Overrides output.base_dir in the configuration file." @@ -282,7 +289,8 @@ def _update_cli( def _prompt_tune_cli( root: Path = typer.Option( Path(), - "--root", "-r", + "--root", + "-r", help="The project root directory.", exists=True, dir_okay=True, @@ -294,7 +302,8 @@ def _prompt_tune_cli( ), config: Path | None = typer.Option( None, - "--config", "-c", + "--config", + "-c", help="The configuration to use.", exists=True, file_okay=True, @@ -305,7 +314,8 @@ def _prompt_tune_cli( ), verbose: bool = typer.Option( False, - "--verbose", "-v", + "--verbose", + "-v", help="Run the prompt tuning pipeline with verbose logging.", ), logger: LoggerType = typer.Option( @@ -374,7 +384,8 @@ def _prompt_tune_cli( ), output: Path = typer.Option( Path("prompts"), - "--output", "-o", + "--output", + "-o", help="The directory to save prompts to, relative to the project root directory.", dir_okay=True, writable=True, @@ -413,17 +424,20 @@ def _prompt_tune_cli( def _query_cli( method: SearchMethod = typer.Option( ..., - "--method", "-m", + "--method", + "-m", help="The query algorithm to use.", ), query: str = typer.Option( ..., - "--query", "-q", + "--query", + "-q", help="The query to execute.", ), config: Path | None = typer.Option( None, - "--config", "-c", + "--config", + "-c", help="The configuration to use.", exists=True, file_okay=True, @@ -434,7 +448,8 @@ def _query_cli( ), data: Path | None = typer.Option( None, - "--data", "-d", + "--data", + "-d", help="Index output directory (contains the parquet files).", exists=True, dir_okay=True, @@ -446,7 +461,8 @@ def _query_cli( ), root: Path = typer.Option( Path.cwd(), - "--root", "-r", + "--root", + "-r", help="The project root directory.", exists=True, dir_okay=True, From 0d52b4b7e6da09b0d1d5ac9b26755a15e58aedad Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Mon, 19 May 2025 19:40:11 -0600 Subject: [PATCH 3/6] Semver --- .semversioner/next-release/minor-20250520014004743256.json | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .semversioner/next-release/minor-20250520014004743256.json diff --git a/.semversioner/next-release/minor-20250520014004743256.json b/.semversioner/next-release/minor-20250520014004743256.json new file mode 100644 index 0000000000..d95f02f677 --- /dev/null +++ b/.semversioner/next-release/minor-20250520014004743256.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Remove Dynamic Max Retries support. Refactor typer typing in cli interface" +} From fdd00d96020cc02b6ac04501d50240134539ef51 Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Mon, 19 May 2025 19:44:27 -0600 Subject: [PATCH 4/6] Fix typo --- graphrag/config/models/language_model_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphrag/config/models/language_model_config.py b/graphrag/config/models/language_model_config.py index 7f2c39a50b..524703dfde 100644 --- a/graphrag/config/models/language_model_config.py +++ b/graphrag/config/models/language_model_config.py @@ -209,7 +209,7 @@ def _validate_tokens_per_minute(self) -> None: """ # If the value is a number, check if it is less than 1 if isinstance(self.tokens_per_minute, int) and self.tokens_per_minute < 1: - msg = f"Tokens per minute must be a non zero postive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}." + msg = f"Tokens per minute must be a non zero positve number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}." raise ValueError(msg) requests_per_minute: int | Literal["auto"] | None = Field( @@ -227,7 +227,7 @@ def _validate_requests_per_minute(self) -> None: """ # If the value is a number, check if it is less than 1 if isinstance(self.requests_per_minute, int) and self.requests_per_minute < 1: - msg = f"Requests per minute must be a non zero postive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}." + msg = f"Requests per minute must be a non zero positve number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}." raise ValueError(msg) retry_strategy: str = Field( From 1aa823aa0459a9791072e0448b08f7440c9152c4 Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Mon, 19 May 2025 20:10:05 -0600 Subject: [PATCH 5/6] Ruff and Typos --- graphrag/api/prompt_tune.py | 2 +- graphrag/cli/main.py | 48 ++++++++++++++++--------------- graphrag/index/validate_config.py | 1 - pyproject.toml | 1 + 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/graphrag/api/prompt_tune.py b/graphrag/api/prompt_tune.py index 29a8172011..c6dbb81df9 100644 --- a/graphrag/api/prompt_tune.py +++ b/graphrag/api/prompt_tune.py @@ -17,7 +17,7 @@ from pydantic import PositiveInt, validate_call from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks -from graphrag.config.defaults import graphrag_config_defaults, language_model_defaults +from graphrag.config.defaults import graphrag_config_defaults from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.language_model.manager import ModelManager from graphrag.logger.base import ProgressLogger diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py index a9efc67bb4..956505e8b1 100644 --- a/graphrag/cli/main.py +++ b/graphrag/cli/main.py @@ -7,7 +7,6 @@ import re from collections.abc import Callable from pathlib import Path -from typing import Annotated import typer @@ -78,6 +77,20 @@ def completer(incomplete: str) -> list[str]: return completer +CONFIG_AUTOCOMPLETE = path_autocomplete( + file_okay=True, + dir_okay=False, + match_wildcard="*.yaml", + readable=True, +) + +ROOT_AUTOCOMPLETE = path_autocomplete( + file_okay=False, + dir_okay=True, + writable=True, + match_wildcard="*", +) + @app.command("init") def _initialize_cli( root: Path = typer.Option( @@ -88,9 +101,7 @@ def _initialize_cli( dir_okay=True, writable=True, resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, writable=True, match_wildcard="*" - ), + autocompletion=ROOT_AUTOCOMPLETE, ), force: bool = typer.Option( False, @@ -115,6 +126,7 @@ def _index_cli( exists=True, file_okay=True, readable=True, + autocompletion=CONFIG_AUTOCOMPLETE, ), root: Path = typer.Option( Path(), @@ -125,9 +137,7 @@ def _index_cli( dir_okay=True, writable=True, resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, writable=True, match_wildcard="*" - ), + autocompletion=ROOT_AUTOCOMPLETE, ), method: IndexingMethod = typer.Option( IndexingMethod.Standard.value, @@ -209,6 +219,7 @@ def _update_cli( exists=True, file_okay=True, readable=True, + autocompletion=CONFIG_AUTOCOMPLETE, ), root: Path = typer.Option( Path(), @@ -219,6 +230,7 @@ def _update_cli( dir_okay=True, writable=True, resolve_path=True, + autocompletion=ROOT_AUTOCOMPLETE, ), method: IndexingMethod = typer.Option( IndexingMethod.Standard.value, @@ -296,9 +308,7 @@ def _prompt_tune_cli( dir_okay=True, writable=True, resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, writable=True, match_wildcard="*" - ), + autocompletion=ROOT_AUTOCOMPLETE, ), config: Path | None = typer.Option( None, @@ -308,9 +318,7 @@ def _prompt_tune_cli( exists=True, file_okay=True, readable=True, - autocompletion=path_autocomplete( - file_okay=True, dir_okay=False, match_wildcard="*" - ), + autocompletion=CONFIG_AUTOCOMPLETE, ), verbose: bool = typer.Option( False, @@ -442,9 +450,7 @@ def _query_cli( exists=True, file_okay=True, readable=True, - autocompletion=path_autocomplete( - file_okay=True, dir_okay=False, match_wildcard="*" - ), + autocompletion=CONFIG_AUTOCOMPLETE, ), data: Path | None = typer.Option( None, @@ -455,12 +461,10 @@ def _query_cli( dir_okay=True, readable=True, resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, match_wildcard="*" - ), + autocompletion=ROOT_AUTOCOMPLETE, ), root: Path = typer.Option( - Path.cwd(), + Path(), "--root", "-r", help="The project root directory.", @@ -468,9 +472,7 @@ def _query_cli( dir_okay=True, writable=True, resolve_path=True, - autocompletion=path_autocomplete( - file_okay=False, dir_okay=True, match_wildcard="*" - ), + autocompletion=ROOT_AUTOCOMPLETE, ), community_level: int = typer.Option( 2, diff --git a/graphrag/index/validate_config.py b/graphrag/index/validate_config.py index 483979f58b..fc75494522 100644 --- a/graphrag/index/validate_config.py +++ b/graphrag/index/validate_config.py @@ -7,7 +7,6 @@ import sys from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks -from graphrag.config.defaults import language_model_defaults from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.language_model.manager import ModelManager from graphrag.logger.print_progress import ProgressLogger diff --git a/pyproject.toml b/pyproject.toml index f8c055fef4..d3e6580c7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -245,6 +245,7 @@ ignore = [ # TODO RE-Enable when we get bandwidth "PERF203", # Needs restructuring of errors, we should bail-out on first error "C901", # needs refactoring to remove cyclomatic complexity + "B008", # Needs to restructure our cli params with Typer into constants ] [tool.ruff.lint.per-file-ignores] From 1bf7bcc974ba7992bd35f37652840d06ff458c9f Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Mon, 19 May 2025 20:14:51 -0600 Subject: [PATCH 6/6] Format --- graphrag/cli/main.py | 1 + graphrag/config/models/language_model_config.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py index 956505e8b1..7f0660d556 100644 --- a/graphrag/cli/main.py +++ b/graphrag/cli/main.py @@ -91,6 +91,7 @@ def completer(incomplete: str) -> list[str]: match_wildcard="*", ) + @app.command("init") def _initialize_cli( root: Path = typer.Option( diff --git a/graphrag/config/models/language_model_config.py b/graphrag/config/models/language_model_config.py index 524703dfde..ddede1f855 100644 --- a/graphrag/config/models/language_model_config.py +++ b/graphrag/config/models/language_model_config.py @@ -209,7 +209,7 @@ def _validate_tokens_per_minute(self) -> None: """ # If the value is a number, check if it is less than 1 if isinstance(self.tokens_per_minute, int) and self.tokens_per_minute < 1: - msg = f"Tokens per minute must be a non zero positve number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}." + msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}." raise ValueError(msg) requests_per_minute: int | Literal["auto"] | None = Field( @@ -227,7 +227,7 @@ def _validate_requests_per_minute(self) -> None: """ # If the value is a number, check if it is less than 1 if isinstance(self.requests_per_minute, int) and self.requests_per_minute < 1: - msg = f"Requests per minute must be a non zero positve number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}." + msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}." raise ValueError(msg) retry_strategy: str = Field(