diff --git a/.semversioner/next-release/patch-20250430211223127781.json b/.semversioner/next-release/patch-20250430211223127781.json
new file mode 100644
index 0000000000..6df99f8240
--- /dev/null
+++ b/.semversioner/next-release/patch-20250430211223127781.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Update as workflows"
+}
diff --git a/graphrag/api/index.py b/graphrag/api/index.py
index 262ffa4d7f..f530bfa4b3 100644
--- a/graphrag/api/index.py
+++ b/graphrag/api/index.py
@@ -65,7 +65,7 @@ async def build_index(
     if memory_profile:
         log.warning("New pipeline does not yet support memory profiling.")
 
-    pipeline = PipelineFactory.create_pipeline(config, method)
+    pipeline = PipelineFactory.create_pipeline(config, method, is_update_run)
 
     workflow_callbacks.pipeline_start(pipeline.names())
diff --git a/graphrag/index/run/run_pipeline.py b/graphrag/index/run/run_pipeline.py
index 26b965f1b5..65c41b7e64 100644
--- a/graphrag/index/run/run_pipeline.py
+++ b/graphrag/index/run/run_pipeline.py
@@ -13,8 +13,6 @@
 
 import pandas as pd
 
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.input.factory import create_input
@@ -22,10 +20,7 @@
 from graphrag.index.typing.context import PipelineRunContext
 from graphrag.index.typing.pipeline import Pipeline
 from graphrag.index.typing.pipeline_run_result import PipelineRunResult
-from graphrag.index.update.incremental_index import (
-    get_delta_docs,
-    update_dataframe_outputs,
-)
+from graphrag.index.update.incremental_index import get_delta_docs
 from graphrag.logger.base import ProgressLogger
 from graphrag.logger.progress import Progress
 from graphrag.storage.pipeline_storage import PipelineStorage
@@ -50,6 +45,10 @@ async def run_pipeline(
 
     dataset = await create_input(config.input, logger, root_dir)
 
+    # load existing state in case any workflows are stateful
+    state_json = await storage.get("context.json")
+    state = json.loads(state_json) if state_json else {}
+
     if is_update_run:
         logger.info("Running incremental indexing.")
 
@@ -62,48 +61,45 @@ async def run_pipeline(
         else:
             update_storage = create_storage_from_config(config.update_index_output)
             # we use this to store the new subset index, and will merge its content with the previous index
-            timestamped_storage = update_storage.child(time.strftime("%Y%m%d-%H%M%S"))
+            update_timestamp = time.strftime("%Y%m%d-%H%M%S")
+            timestamped_storage = update_storage.child(update_timestamp)
             delta_storage = timestamped_storage.child("delta")
             # copy the previous output to a backup folder, so we can replace it with the update
             # we'll read from this later when we merge the old and new indexes
             previous_storage = timestamped_storage.child("previous")
             await _copy_previous_output(storage, previous_storage)
 
+            state["update_timestamp"] = update_timestamp
+
+            context = create_run_context(
+                storage=delta_storage, cache=cache, callbacks=callbacks, state=state
+            )
+
             # Run the pipeline on the new documents
             async for table in _run_pipeline(
                 pipeline=pipeline,
                 config=config,
                 dataset=delta_dataset.new_inputs,
-                cache=cache,
-                storage=delta_storage,
-                callbacks=callbacks,
                 logger=logger,
+                context=context,
             ):
                 yield table
 
             logger.success("Finished running workflows on new documents.")
 
-            await update_dataframe_outputs(
-                previous_storage=previous_storage,
-                delta_storage=delta_storage,
-                output_storage=storage,
-                config=config,
-                cache=cache,
-                callbacks=NoopWorkflowCallbacks(),
-                progress_logger=logger,
-            )
-
     else:
         logger.info("Running standard indexing.")
 
+        context = create_run_context(
+            storage=storage, cache=cache, callbacks=callbacks, state=state
+        )
+
         async for table in _run_pipeline(
             pipeline=pipeline,
             config=config,
             dataset=dataset,
-            cache=cache,
-            storage=storage,
-            callbacks=callbacks,
             logger=logger,
+            context=context,
         ):
             yield table
@@ -112,21 +108,11 @@
 async def _run_pipeline(
     pipeline: Pipeline,
     config: GraphRagConfig,
     dataset: pd.DataFrame,
-    cache: PipelineCache,
-    storage: PipelineStorage,
-    callbacks: WorkflowCallbacks,
     logger: ProgressLogger,
+    context: PipelineRunContext,
 ) -> AsyncIterable[PipelineRunResult]:
     start_time = time.time()
 
-    # load existing state in case any workflows are stateful
-    state_json = await storage.get("context.json")
-    state = json.loads(state_json) if state_json else {}
-
-    context = create_run_context(
-        storage=storage, cache=cache, callbacks=callbacks, state=state
-    )
-
     log.info("Final # of rows loaded: %s", len(dataset))
     context.stats.num_documents = len(dataset)
     last_workflow = "starting documents"
@@ -138,11 +124,11 @@ async def _run_pipeline(
         for name, workflow_function in pipeline.run():
             last_workflow = name
             progress = logger.child(name, transient=False)
-            callbacks.workflow_start(name, None)
+            context.callbacks.workflow_start(name, None)
             work_time = time.time()
             result = await workflow_function(config, context)
             progress(Progress(percent=1))
-            callbacks.workflow_end(name, result)
+            context.callbacks.workflow_end(name, result)
             yield PipelineRunResult(
                 workflow=name, result=result.result, state=context.state, errors=None
             )
@@ -154,7 +140,7 @@ async def _run_pipeline(
     except Exception as e:
         log.exception("error running workflow %s", last_workflow)
-        callbacks.error("Error running pipeline!", e, traceback.format_exc())
+        context.callbacks.error("Error running pipeline!", e, traceback.format_exc())
         yield PipelineRunResult(
             workflow=last_workflow, result=None, state=context.state, errors=[e]
        )
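The context.json round-trip moved into run_pipeline above is what lets the new update workflows share data across the whole run: the state is hydrated once, every workflow mutates context.state, and the serialized state travels with the pipeline results into the next run. A minimal sketch of that lifecycle (the "context.json" key and state handling follow this diff; the dict-backed storage and the timestamp value are illustrative stand-ins):

    import json

    storage: dict[str, str] = {}  # stand-in for an async PipelineStorage backend

    # run_pipeline: hydrate any state persisted by a previous run
    state_json = storage.get("context.json")
    state = json.loads(state_json) if state_json else {}

    # workflows: read and write shared keys on the run context's state
    state["update_timestamp"] = "20250430-211223"  # hypothetical value

    # after the run: persist the state so a later (e.g. incremental) run sees it
    storage["context.json"] = json.dumps(state)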
diff --git a/graphrag/index/run/utils.py b/graphrag/index/run/utils.py
index 79ab963123..a5c2307439 100644
--- a/graphrag/index/run/utils.py
+++ b/graphrag/index/run/utils.py
@@ -9,12 +9,14 @@
 from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
+from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.typing.context import PipelineRunContext
 from graphrag.index.typing.state import PipelineState
 from graphrag.index.typing.stats import PipelineRunStats
 from graphrag.logger.base import ProgressLogger
 from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
 from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.api import create_storage_from_config
 
 
 def create_run_context(
@@ -44,3 +46,16 @@ def create_callback_chain(
     if progress is not None:
         manager.register(ProgressWorkflowCallbacks(progress))
     return manager
+
+
+def get_update_storages(
+    config: GraphRagConfig, timestamp: str
+) -> tuple[PipelineStorage, PipelineStorage, PipelineStorage]:
+    """Get storage objects for the update index run."""
+    output_storage = create_storage_from_config(config.output)
+    update_storage = create_storage_from_config(config.update_index_output)
+    timestamped_storage = update_storage.child(timestamp)
+    delta_storage = timestamped_storage.child("delta")
+    previous_storage = timestamped_storage.child("previous")
+
+    return output_storage, previous_storage, delta_storage
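get_update_storages re-derives the same three locations that run_pipeline establishes, so any update workflow can reconstruct them from the config plus the timestamp stashed in state. The resulting layout is roughly as follows (the roots come from config.output and config.update_index_output; the concrete paths here are illustrative):

    from pathlib import Path

    output_root = Path("output")          # final merged index (config.output)
    update_root = Path("update_output")   # scratch space (config.update_index_output)

    timestamped = update_root / "20250430-211223"  # state["update_timestamp"]
    delta = timestamped / "delta"         # fresh index built from new documents only
    previous = timestamped / "previous"   # backup copy of the pre-update index

    # update workflows read from `previous` and `delta`, then write the
    # merged tables into `output_root`, mirroring get_update_storages()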
diff --git a/graphrag/index/update/incremental_index.py b/graphrag/index/update/incremental_index.py
index 4d3bfcef1b..ac56e30df4 100644
--- a/graphrag/index/update/incremental_index.py
+++ b/graphrag/index/update/incremental_index.py
@@ -8,25 +8,9 @@
 import numpy as np
 import pandas as pd
 
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.update.communities import (
-    _update_and_merge_communities,
-    _update_and_merge_community_reports,
-)
-from graphrag.index.update.entities import (
-    _group_and_resolve_entities,
-)
-from graphrag.index.update.relationships import _update_and_merge_relationships
-from graphrag.index.workflows.extract_graph import get_summarized_entities_relationships
-from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings
-from graphrag.logger.print_progress import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 from graphrag.utils.storage import (
     load_table_from_storage,
-    storage_has_table,
     write_table_to_storage,
 )
@@ -79,216 +63,7 @@ async def get_delta_docs(
     return InputDelta(new_docs, deleted_docs)
 
 
-async def update_dataframe_outputs(
-    previous_storage: PipelineStorage,
-    delta_storage: PipelineStorage,
-    output_storage: PipelineStorage,
-    config: GraphRagConfig,
-    cache: PipelineCache,
-    callbacks: WorkflowCallbacks,
-    progress_logger: ProgressLogger,
-) -> None:
-    """Update the mergeable outputs.
-
-    Parameters
-    ----------
-    previous_storage : PipelineStorage
-        The storage used to store the dataframes in the original run.
-    delta_storage : PipelineStorage
-        The storage used to store the subset of new dataframes in the update run.
-    output_storage : PipelineStorage
-        The storage used to store the updated dataframes (the final incremental output).
-    """
-    progress_logger.info("Updating Documents")
-    final_documents_df = await _concat_dataframes(
-        "documents", previous_storage, delta_storage, output_storage
-    )
-
-    # Update entities, relationships and merge them
-    progress_logger.info("Updating Entities and Relationships")
-    (
-        merged_entities_df,
-        merged_relationships_df,
-        entity_id_mapping,
-    ) = await _update_entities_and_relationships(
-        previous_storage, delta_storage, output_storage, config, cache, callbacks
-    )
-
-    # Update and merge final text units
-    progress_logger.info("Updating Text Units")
-    merged_text_units = await _update_text_units(
-        previous_storage, delta_storage, output_storage, entity_id_mapping
-    )
-
-    # Merge final covariates
-    if await storage_has_table(
-        "covariates", previous_storage
-    ) and await storage_has_table("covariates", delta_storage):
-        progress_logger.info("Updating Covariates")
-        await _update_covariates(previous_storage, delta_storage, output_storage)
-
-    # Merge final communities
-    progress_logger.info("Updating Communities")
-    community_id_mapping = await _update_communities(
-        previous_storage, delta_storage, output_storage
-    )
-
-    # Merge community reports
-    progress_logger.info("Updating Community Reports")
-    merged_community_reports = await _update_community_reports(
-        previous_storage, delta_storage, output_storage, community_id_mapping
-    )
-
-    # Generate text embeddings
-    progress_logger.info("Updating Text Embeddings")
-    embedded_fields = get_embedded_fields(config)
-    text_embed = get_embedding_settings(config)
-    result = await generate_text_embeddings(
-        documents=final_documents_df,
-        relationships=merged_relationships_df,
-        text_units=merged_text_units,
-        entities=merged_entities_df,
-        community_reports=merged_community_reports,
-        callbacks=callbacks,
-        cache=cache,
-        text_embed_config=text_embed,
-        embedded_fields=embedded_fields,
-    )
-    if config.snapshots.embeddings:
-        for name, table in result.items():
-            await write_table_to_storage(
-                table,
-                f"embeddings.{name}",
-                output_storage,
-            )
-
-
-async def _update_community_reports(
-    previous_storage: PipelineStorage,
-    delta_storage: PipelineStorage,
-    output_storage: PipelineStorage,
-    community_id_mapping: dict,
-) -> pd.DataFrame:
-    """Update the community reports output."""
-    old_community_reports = await load_table_from_storage(
-        "community_reports", previous_storage
-    )
-    delta_community_reports = await load_table_from_storage(
-        "community_reports", delta_storage
-    )
-    merged_community_reports = _update_and_merge_community_reports(
-        old_community_reports, delta_community_reports, community_id_mapping
-    )
-
-    await write_table_to_storage(
-        merged_community_reports, "community_reports", output_storage
-    )
-
-    return merged_community_reports
-
-
-async def _update_communities(
-    previous_storage: PipelineStorage,
-    delta_storage: PipelineStorage,
-    output_storage: PipelineStorage,
-) -> dict:
-    """Update the communities output."""
-    old_communities = await load_table_from_storage("communities", previous_storage)
-    delta_communities = await load_table_from_storage("communities", delta_storage)
-    merged_communities, community_id_mapping = _update_and_merge_communities(
-        old_communities, delta_communities
-    )
-
-    await write_table_to_storage(merged_communities, "communities", output_storage)
-
-    return community_id_mapping
-
-
-async def _update_covariates(
-    previous_storage: PipelineStorage,
-    delta_storage: PipelineStorage,
-    output_storage: PipelineStorage,
-) -> None:
-    """Update the covariates output."""
-    old_covariates = await load_table_from_storage("covariates", previous_storage)
-    delta_covariates = await load_table_from_storage("covariates", delta_storage)
-    merged_covariates = _merge_covariates(old_covariates, delta_covariates)
-
-    await write_table_to_storage(merged_covariates, "covariates", output_storage)
-
-
-async def _update_text_units(
-    previous_storage: PipelineStorage,
-    delta_storage: PipelineStorage,
-    output_storage: PipelineStorage,
-    entity_id_mapping: dict,
-) -> pd.DataFrame:
-    """Update the text units output."""
-    old_text_units = await load_table_from_storage("text_units", previous_storage)
-    delta_text_units = await load_table_from_storage("text_units", delta_storage)
-    merged_text_units = _update_and_merge_text_units(
-        old_text_units, delta_text_units, entity_id_mapping
-    )
-
-    await write_table_to_storage(merged_text_units, "text_units", output_storage)
-
-    return merged_text_units
-
-
-async def _update_entities_and_relationships(
-    previous_storage: PipelineStorage,
-    delta_storage: PipelineStorage,
-    output_storage: PipelineStorage,
-    config: GraphRagConfig,
-    cache: PipelineCache,
-    callbacks: WorkflowCallbacks,
-) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
-    """Update Final Entities and Relationships output."""
-    old_entities = await load_table_from_storage("entities", previous_storage)
-    delta_entities = await load_table_from_storage("entities", delta_storage)
-
-    merged_entities_df, entity_id_mapping = _group_and_resolve_entities(
-        old_entities, delta_entities
-    )
-
-    # Update Relationships
-    old_relationships = await load_table_from_storage("relationships", previous_storage)
-    delta_relationships = await load_table_from_storage("relationships", delta_storage)
-    merged_relationships_df = _update_and_merge_relationships(
-        old_relationships,
-        delta_relationships,
-    )
-
-    summarization_llm_settings = config.get_language_model_config(
-        config.summarize_descriptions.model_id
-    )
-    summarization_strategy = config.summarize_descriptions.resolved_strategy(
-        config.root_dir, summarization_llm_settings
-    )
-
-    (
-        merged_entities_df,
-        merged_relationships_df,
-    ) = await get_summarized_entities_relationships(
-        extracted_entities=merged_entities_df,
-        extracted_relationships=merged_relationships_df,
-        callbacks=callbacks,
-        cache=cache,
-        summarization_strategy=summarization_strategy,
-        summarization_num_threads=summarization_llm_settings.concurrent_requests,
-    )
-
-    # Save the updated entities back to storage
-    await write_table_to_storage(merged_entities_df, "entities", output_storage)
-
-    await write_table_to_storage(
-        merged_relationships_df, "relationships", output_storage
-    )
-
-    return merged_entities_df, merged_relationships_df, entity_id_mapping
-
-
-async def _concat_dataframes(
+async def concat_dataframes(
     name: str,
     previous_storage: PipelineStorage,
     delta_storage: PipelineStorage,
@@ -306,65 +81,3 @@ async def concat_dataframes(
     await write_table_to_storage(final_df, name, output_storage)
 
     return final_df
-
-
-def _update_and_merge_text_units(
-    old_text_units: pd.DataFrame,
-    delta_text_units: pd.DataFrame,
-    entity_id_mapping: dict,
-) -> pd.DataFrame:
-    """Update and merge text units.
-
-    Parameters
-    ----------
-    old_text_units : pd.DataFrame
-        The old text units.
-    delta_text_units : pd.DataFrame
-        The delta text units.
-    entity_id_mapping : dict
-        The entity id mapping.
-
-    Returns
-    -------
-    pd.DataFrame
-        The updated text units.
-    """
-    # Look for entity ids in entity_ids and replace them with the corresponding id in the mapping
-    if entity_id_mapping:
-        delta_text_units["entity_ids"] = delta_text_units["entity_ids"].apply(
-            lambda x: [entity_id_mapping.get(i, i) for i in x] if x is not None else x
-        )
-
-    initial_id = old_text_units["human_readable_id"].max() + 1
-    delta_text_units["human_readable_id"] = np.arange(
-        initial_id, initial_id + len(delta_text_units)
-    )
-    # Merge the final text units
-    return pd.concat([old_text_units, delta_text_units], ignore_index=True, copy=False)
-
-
-def _merge_covariates(
-    old_covariates: pd.DataFrame, delta_covariates: pd.DataFrame
-) -> pd.DataFrame:
-    """Merge the covariates.
-
-    Parameters
-    ----------
-    old_covariates : pd.DataFrame
-        The old covariates.
-    delta_covariates : pd.DataFrame
-        The delta covariates.
-
-    Returns
-    -------
-    pd.DataFrame
-        The merged covariates.
-    """
-    # Get the max human readable id from the old covariates and update the delta covariates
-    initial_id = old_covariates["human_readable_id"].max() + 1
-    delta_covariates["human_readable_id"] = np.arange(
-        initial_id, initial_id + len(delta_covariates)
-    )
-
-    # Concatenate the old and delta covariates
-    return pd.concat([old_covariates, delta_covariates], ignore_index=True, copy=False)
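With the orchestration removed, incremental_index.py keeps only the delta detection and the now-public concat_dataframes primitive. Its effect is equivalent to this toy pandas sketch (the real function loads and writes named tables through PipelineStorage; the frames here are illustrative):

    import pandas as pd

    previous = pd.DataFrame({"id": ["a", "b"], "title": ["doc-a", "doc-b"]})
    delta = pd.DataFrame({"id": ["c"], "title": ["doc-c"]})

    # load the same-named table from both storages, append, write, return
    final = pd.concat([previous, delta], ignore_index=True)
    print(final["id"].tolist())  # ['a', 'b', 'c']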
- """ - progress_logger.info("Updating Documents") - final_documents_df = await _concat_dataframes( - "documents", previous_storage, delta_storage, output_storage - ) - - # Update entities, relationships and merge them - progress_logger.info("Updating Entities and Relationships") - ( - merged_entities_df, - merged_relationships_df, - entity_id_mapping, - ) = await _update_entities_and_relationships( - previous_storage, delta_storage, output_storage, config, cache, callbacks - ) - - # Update and merge final text units - progress_logger.info("Updating Text Units") - merged_text_units = await _update_text_units( - previous_storage, delta_storage, output_storage, entity_id_mapping - ) - - # Merge final covariates - if await storage_has_table( - "covariates", previous_storage - ) and await storage_has_table("covariates", delta_storage): - progress_logger.info("Updating Covariates") - await _update_covariates(previous_storage, delta_storage, output_storage) - - # Merge final communities - progress_logger.info("Updating Communities") - community_id_mapping = await _update_communities( - previous_storage, delta_storage, output_storage - ) - - # Merge community reports - progress_logger.info("Updating Community Reports") - merged_community_reports = await _update_community_reports( - previous_storage, delta_storage, output_storage, community_id_mapping - ) - - # Generate text embeddings - progress_logger.info("Updating Text Embeddings") - embedded_fields = get_embedded_fields(config) - text_embed = get_embedding_settings(config) - result = await generate_text_embeddings( - documents=final_documents_df, - relationships=merged_relationships_df, - text_units=merged_text_units, - entities=merged_entities_df, - community_reports=merged_community_reports, - callbacks=callbacks, - cache=cache, - text_embed_config=text_embed, - embedded_fields=embedded_fields, - ) - if config.snapshots.embeddings: - for name, table in result.items(): - await write_table_to_storage( - table, - f"embeddings.{name}", - output_storage, - ) - - -async def _update_community_reports( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, - community_id_mapping: dict, -) -> pd.DataFrame: - """Update the community reports output.""" - old_community_reports = await load_table_from_storage( - "community_reports", previous_storage - ) - delta_community_reports = await load_table_from_storage( - "community_reports", delta_storage - ) - merged_community_reports = _update_and_merge_community_reports( - old_community_reports, delta_community_reports, community_id_mapping - ) - - await write_table_to_storage( - merged_community_reports, "community_reports", output_storage - ) - - return merged_community_reports - - -async def _update_communities( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, -) -> dict: - """Update the communities output.""" - old_communities = await load_table_from_storage("communities", previous_storage) - delta_communities = await load_table_from_storage("communities", delta_storage) - merged_communities, community_id_mapping = _update_and_merge_communities( - old_communities, delta_communities - ) - - await write_table_to_storage(merged_communities, "communities", output_storage) - - return community_id_mapping - - -async def _update_covariates( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, -) -> None: - """Update the covariates output.""" - 
diff --git a/graphrag/index/workflows/factory.py b/graphrag/index/workflows/factory.py
index b68ccf55e0..c73e64b66d 100644
--- a/graphrag/index/workflows/factory.py
+++ b/graphrag/index/workflows/factory.py
@@ -29,19 +29,35 @@ def register_all(cls, workflows: dict[str, WorkflowFunction]):
 
     @classmethod
     def create_pipeline(
-        cls, config: GraphRagConfig, method: IndexingMethod = IndexingMethod.Standard
+        cls,
+        config: GraphRagConfig,
+        method: IndexingMethod = IndexingMethod.Standard,
+        is_update_run: bool = False,
     ) -> Pipeline:
         """Create a pipeline generator."""
-        workflows = _get_workflows_list(config, method)
+        workflows = _get_workflows_list(config, method, is_update_run)
         return Pipeline([(name, cls.workflows[name]) for name in workflows])
 
 
 def _get_workflows_list(
-    config: GraphRagConfig, method: IndexingMethod = IndexingMethod.Standard
+    config: GraphRagConfig,
+    method: IndexingMethod = IndexingMethod.Standard,
+    is_update_run: bool = False,
 ) -> list[str]:
     """Return a list of workflows for the indexing pipeline."""
+    update_workflows = [
+        "update_final_documents",
+        "update_entities_relationships",
+        "update_text_units",
+        "update_covariates",
+        "update_communities",
+        "update_community_reports",
+        "update_text_embeddings",
+        "update_clean_state",
+    ]
     if config.workflows:
         return config.workflows
+
     match method:
         case IndexingMethod.Standard:
             return [
@@ -54,6 +70,7 @@ def _get_workflows_list(
                 "create_final_text_units",
                 "create_community_reports",
                 "generate_text_embeddings",
+                *(update_workflows if is_update_run else []),
             ]
         case IndexingMethod.Fast:
             return [
@@ -66,4 +83,5 @@ def _get_workflows_list(
                 "create_final_text_units",
                 "create_community_reports_text",
                 "generate_text_embeddings",
+                *(update_workflows if is_update_run else []),
             ]
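Callers opt into the appended update_* workflows through the new flag; usage is roughly as follows (a sketch assuming these import paths; loading the GraphRagConfig is omitted):

    from graphrag.config.enums import IndexingMethod
    from graphrag.index.workflows.factory import PipelineFactory

    # `config` is a loaded GraphRagConfig
    pipeline = PipelineFactory.create_pipeline(
        config, IndexingMethod.Standard, is_update_run=True
    )

    # standard workflows run first, then the update_* workflows that
    # _get_workflows_list appends when is_update_run is True
    print(pipeline.names())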
- """ - # Look for entity ids in entity_ids and replace them with the corresponding id in the mapping - if entity_id_mapping: - delta_text_units["entity_ids"] = delta_text_units["entity_ids"].apply( - lambda x: [entity_id_mapping.get(i, i) for i in x] if x is not None else x - ) - - initial_id = old_text_units["human_readable_id"].max() + 1 - delta_text_units["human_readable_id"] = np.arange( - initial_id, initial_id + len(delta_text_units) - ) - # Merge the final text units - return pd.concat([old_text_units, delta_text_units], ignore_index=True, copy=False) - - -def _merge_covariates( - old_covariates: pd.DataFrame, delta_covariates: pd.DataFrame -) -> pd.DataFrame: - """Merge the covariates. - - Parameters - ---------- - old_covariates : pd.DataFrame - The old covariates. - delta_covariates : pd.DataFrame - The delta covariates. - - Returns - ------- - pd.DataFrame - The merged covariates. - """ - # Get the max human readable id from the old covariates and update the delta covariates - initial_id = old_covariates["human_readable_id"].max() + 1 - delta_covariates["human_readable_id"] = np.arange( - initial_id, initial_id + len(delta_covariates) - ) - - # Concatenate the old and delta covariates - return pd.concat([old_covariates, delta_covariates], ignore_index=True, copy=False) diff --git a/graphrag/index/workflows/__init__.py b/graphrag/index/workflows/__init__.py index 425639be0b..7f38a3cc63 100644 --- a/graphrag/index/workflows/__init__.py +++ b/graphrag/index/workflows/__init__.py @@ -42,6 +42,30 @@ from .prune_graph import ( run_workflow as run_prune_graph, ) +from .update_clean_state import ( + run_workflow as run_update_clean_state, +) +from .update_communities import ( + run_workflow as run_update_communities, +) +from .update_community_reports import ( + run_workflow as run_update_community_reports, +) +from .update_covariates import ( + run_workflow as run_update_covariates, +) +from .update_entities_relationships import ( + run_workflow as run_update_entities_relationships, +) +from .update_final_documents import ( + run_workflow as run_update_final_documents, +) +from .update_text_embeddings import ( + run_workflow as run_update_text_embeddings, +) +from .update_text_units import ( + run_workflow as run_update_text_units, +) # register all of our built-in workflows at once PipelineFactory.register_all({ @@ -57,4 +81,12 @@ "finalize_graph": run_finalize_graph, "generate_text_embeddings": run_generate_text_embeddings, "prune_graph": run_prune_graph, + "update_final_documents": run_update_final_documents, + "update_text_embeddings": run_update_text_embeddings, + "update_community_reports": run_update_community_reports, + "update_entities_relationships": run_update_entities_relationships, + "update_communities": run_update_communities, + "update_covariates": run_update_covariates, + "update_text_units": run_update_text_units, + "update_clean_state": run_update_clean_state, }) diff --git a/graphrag/index/workflows/factory.py b/graphrag/index/workflows/factory.py index b68ccf55e0..c73e64b66d 100644 --- a/graphrag/index/workflows/factory.py +++ b/graphrag/index/workflows/factory.py @@ -29,19 +29,35 @@ def register_all(cls, workflows: dict[str, WorkflowFunction]): @classmethod def create_pipeline( - cls, config: GraphRagConfig, method: IndexingMethod = IndexingMethod.Standard + cls, + config: GraphRagConfig, + method: IndexingMethod = IndexingMethod.Standard, + is_update_run: bool = False, ) -> Pipeline: """Create a pipeline generator.""" - workflows = _get_workflows_list(config, 
diff --git a/graphrag/index/workflows/update_communities.py b/graphrag/index/workflows/update_communities.py
new file mode 100644
index 0000000000..14c8826b75
--- /dev/null
+++ b/graphrag/index/workflows/update_communities.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.utils import get_update_storages
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.index.update.communities import _update_and_merge_communities
+from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
+
+logger = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Update the communities from an incremental index run."""
+    logger.info("Updating Communities")
+    output_storage, previous_storage, delta_storage = get_update_storages(
+        config, context.state["update_timestamp"]
+    )
+
+    community_id_mapping = await _update_communities(
+        previous_storage, delta_storage, output_storage
+    )
+
+    context.state["incremental_update_community_id_mapping"] = community_id_mapping
+
+    return WorkflowFunctionOutput(result=None)
+
+
+async def _update_communities(
+    previous_storage: PipelineStorage,
+    delta_storage: PipelineStorage,
+    output_storage: PipelineStorage,
+) -> dict:
+    """Update the communities output."""
+    old_communities = await load_table_from_storage("communities", previous_storage)
+    delta_communities = await load_table_from_storage("communities", delta_storage)
+    merged_communities, community_id_mapping = _update_and_merge_communities(
+        old_communities, delta_communities
+    )
+
+    await write_table_to_storage(merged_communities, "communities", output_storage)
+
+    return community_id_mapping
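update_communities establishes the handoff convention used by all of these workflows: results move between workflows via context.state under an incremental_update_ key prefix rather than via return values, which is what update_clean_state relies on to sweep them at the end. A self-contained sketch of the convention (the key names follow this diff; the values are hypothetical):

    # stand-in for context.state, a plain dict on PipelineRunContext
    state: dict = {"update_timestamp": "20250430-211223"}

    # update_communities produces the mapping...
    state["incremental_update_community_id_mapping"] = {5: 12, 6: 13}

    # ...and update_community_reports consumes it downstream
    mapping = state["incremental_update_community_id_mapping"]

    # the shared prefix lets update_clean_state sweep every handoff key
    for key in [k for k in state if k.startswith("incremental_update_")]:
        del state[key]

    assert "update_timestamp" in state  # non-prefixed keys survive the sweep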
diff --git a/graphrag/index/workflows/update_community_reports.py b/graphrag/index/workflows/update_community_reports.py
new file mode 100644
index 0000000000..2dc0feb3a2
--- /dev/null
+++ b/graphrag/index/workflows/update_community_reports.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+import pandas as pd
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.utils import get_update_storages
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.index.update.communities import _update_and_merge_community_reports
+from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
+
+logger = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Update the community reports from an incremental index run."""
+    logger.info("Updating Community Reports")
+    output_storage, previous_storage, delta_storage = get_update_storages(
+        config, context.state["update_timestamp"]
+    )
+
+    community_id_mapping = context.state["incremental_update_community_id_mapping"]
+
+    merged_community_reports = await _update_community_reports(
+        previous_storage, delta_storage, output_storage, community_id_mapping
+    )
+
+    context.state["incremental_update_merged_community_reports"] = (
+        merged_community_reports
+    )
+
+    return WorkflowFunctionOutput(result=None)
+
+
+async def _update_community_reports(
+    previous_storage: PipelineStorage,
+    delta_storage: PipelineStorage,
+    output_storage: PipelineStorage,
+    community_id_mapping: dict,
+) -> pd.DataFrame:
+    """Update the community reports output."""
+    old_community_reports = await load_table_from_storage(
+        "community_reports", previous_storage
+    )
+    delta_community_reports = await load_table_from_storage(
+        "community_reports", delta_storage
+    )
+    merged_community_reports = _update_and_merge_community_reports(
+        old_community_reports, delta_community_reports, community_id_mapping
+    )
+
+    await write_table_to_storage(
+        merged_community_reports, "community_reports", output_storage
+    )
+
+    return merged_community_reports
diff --git a/graphrag/index/workflows/update_covariates.py b/graphrag/index/workflows/update_covariates.py
new file mode 100644
index 0000000000..1239de144a
--- /dev/null
+++ b/graphrag/index/workflows/update_covariates.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+import numpy as np
+import pandas as pd
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.utils import get_update_storages
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.storage import (
+    load_table_from_storage,
+    storage_has_table,
+    write_table_to_storage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Update the covariates from an incremental index run."""
+    output_storage, previous_storage, delta_storage = get_update_storages(
+        config, context.state["update_timestamp"]
+    )
+
+    if await storage_has_table(
+        "covariates", previous_storage
+    ) and await storage_has_table("covariates", delta_storage):
+        logger.info("Updating Covariates")
+        await _update_covariates(previous_storage, delta_storage, output_storage)
+
+    return WorkflowFunctionOutput(result=None)
+
+
+async def _update_covariates(
+    previous_storage: PipelineStorage,
+    delta_storage: PipelineStorage,
+    output_storage: PipelineStorage,
+) -> None:
+    """Update the covariates output."""
+    old_covariates = await load_table_from_storage("covariates", previous_storage)
+    delta_covariates = await load_table_from_storage("covariates", delta_storage)
+    merged_covariates = _merge_covariates(old_covariates, delta_covariates)
+
+    await write_table_to_storage(merged_covariates, "covariates", output_storage)
+
+
+def _merge_covariates(
+    old_covariates: pd.DataFrame, delta_covariates: pd.DataFrame
+) -> pd.DataFrame:
+    """Merge the covariates.
+
+    Parameters
+    ----------
+    old_covariates : pd.DataFrame
+        The old covariates.
+    delta_covariates : pd.DataFrame
+        The delta covariates.
+
+    Returns
+    -------
+    pd.DataFrame
+        The merged covariates.
+    """
+    # Get the max human readable id from the old covariates and update the delta covariates
+    initial_id = old_covariates["human_readable_id"].max() + 1
+    delta_covariates["human_readable_id"] = np.arange(
+        initial_id, initial_id + len(delta_covariates)
+    )
+
+    # Concatenate the old and delta covariates
+    return pd.concat([old_covariates, delta_covariates], ignore_index=True, copy=False)
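The human_readable_id renumbering in _merge_covariates continues numbering after the largest existing id so the merged ids stay unique; on toy data it behaves like this (the text-units merge below applies the same pattern):

    import numpy as np
    import pandas as pd

    old = pd.DataFrame({"human_readable_id": [1, 2]})
    delta = pd.DataFrame({"human_readable_id": [1]})

    initial_id = old["human_readable_id"].max() + 1  # 3
    delta["human_readable_id"] = np.arange(initial_id, initial_id + len(delta))

    merged = pd.concat([old, delta], ignore_index=True, copy=False)
    print(merged["human_readable_id"].tolist())  # [1, 2, 3]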
diff --git a/graphrag/index/workflows/update_entities_relationships.py b/graphrag/index/workflows/update_entities_relationships.py
new file mode 100644
index 0000000000..0702d62776
--- /dev/null
+++ b/graphrag/index/workflows/update_entities_relationships.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+import pandas as pd
+
+from graphrag.cache.pipeline_cache import PipelineCache
+from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.utils import get_update_storages
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.index.update.entities import _group_and_resolve_entities
+from graphrag.index.update.relationships import _update_and_merge_relationships
+from graphrag.index.workflows.extract_graph import get_summarized_entities_relationships
+from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
+
+logger = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Update the entities and relationships from an incremental index run."""
+    logger.info("Updating Entities and Relationships")
+    output_storage, previous_storage, delta_storage = get_update_storages(
+        config, context.state["update_timestamp"]
+    )
+
+    (
+        merged_entities_df,
+        merged_relationships_df,
+        entity_id_mapping,
+    ) = await _update_entities_and_relationships(
+        previous_storage,
+        delta_storage,
+        output_storage,
+        config,
+        context.cache,
+        context.callbacks,
+    )
+
+    context.state["incremental_update_merged_entities"] = merged_entities_df
+    context.state["incremental_update_merged_relationships"] = merged_relationships_df
+    context.state["incremental_update_entity_id_mapping"] = entity_id_mapping
+
+    return WorkflowFunctionOutput(result=None)
+
+
+async def _update_entities_and_relationships(
+    previous_storage: PipelineStorage,
+    delta_storage: PipelineStorage,
+    output_storage: PipelineStorage,
+    config: GraphRagConfig,
+    cache: PipelineCache,
+    callbacks: WorkflowCallbacks,
+) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
+    """Update Final Entities and Relationships output."""
+    old_entities = await load_table_from_storage("entities", previous_storage)
+    delta_entities = await load_table_from_storage("entities", delta_storage)
+
+    merged_entities_df, entity_id_mapping = _group_and_resolve_entities(
+        old_entities, delta_entities
+    )
+
+    # Update Relationships
+    old_relationships = await load_table_from_storage("relationships", previous_storage)
+    delta_relationships = await load_table_from_storage("relationships", delta_storage)
+    merged_relationships_df = _update_and_merge_relationships(
+        old_relationships,
+        delta_relationships,
+    )
+
+    summarization_llm_settings = config.get_language_model_config(
+        config.summarize_descriptions.model_id
+    )
+    summarization_strategy = config.summarize_descriptions.resolved_strategy(
+        config.root_dir, summarization_llm_settings
+    )
+
+    (
+        merged_entities_df,
+        merged_relationships_df,
+    ) = await get_summarized_entities_relationships(
+        extracted_entities=merged_entities_df,
+        extracted_relationships=merged_relationships_df,
+        callbacks=callbacks,
+        cache=cache,
+        summarization_strategy=summarization_strategy,
+        summarization_num_threads=summarization_llm_settings.concurrent_requests,
+    )
+
+    # Save the updated entities back to storage
+    await write_table_to_storage(merged_entities_df, "entities", output_storage)
+
+    await write_table_to_storage(
+        merged_relationships_df, "relationships", output_storage
+    )
+
+    return merged_entities_df, merged_relationships_df, entity_id_mapping
diff --git a/graphrag/index/workflows/update_final_documents.py b/graphrag/index/workflows/update_final_documents.py
new file mode 100644
index 0000000000..485cb7b10f
--- /dev/null
+++ b/graphrag/index/workflows/update_final_documents.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.utils import get_update_storages
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.index.update.incremental_index import concat_dataframes
+
+logger = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Update the documents from an incremental index run."""
+    logger.info("Updating Documents")
+    output_storage, previous_storage, delta_storage = get_update_storages(
+        config, context.state["update_timestamp"]
+    )
+
+    final_documents = await concat_dataframes(
+        "documents", previous_storage, delta_storage, output_storage
+    )
+
+    context.state["incremental_update_final_documents"] = final_documents
+
+    return WorkflowFunctionOutput(result=None)
diff --git a/graphrag/index/workflows/update_text_embeddings.py b/graphrag/index/workflows/update_text_embeddings.py
new file mode 100644
index 0000000000..c20fb1bf04
--- /dev/null
+++ b/graphrag/index/workflows/update_text_embeddings.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.utils import get_update_storages
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings
+from graphrag.utils.storage import write_table_to_storage
+
+logger = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Update the text embeddings from an incremental index run."""
+    logger.info("Updating Text Embeddings")
+    output_storage, _, _ = get_update_storages(
+        config, context.state["update_timestamp"]
+    )
+
+    final_documents_df = context.state["incremental_update_final_documents"]
+    merged_relationships_df = context.state["incremental_update_merged_relationships"]
+    merged_text_units = context.state["incremental_update_merged_text_units"]
+    merged_entities_df = context.state["incremental_update_merged_entities"]
+    merged_community_reports = context.state[
+        "incremental_update_merged_community_reports"
+    ]
+
+    embedded_fields = get_embedded_fields(config)
+    text_embed = get_embedding_settings(config)
+    result = await generate_text_embeddings(
+        documents=final_documents_df,
+        relationships=merged_relationships_df,
+        text_units=merged_text_units,
+        entities=merged_entities_df,
+        community_reports=merged_community_reports,
+        callbacks=context.callbacks,
+        cache=context.cache,
+        text_embed_config=text_embed,
+        embedded_fields=embedded_fields,
+    )
+
+    if config.snapshots.embeddings:
+        for name, table in result.items():
+            await write_table_to_storage(
+                table,
+                f"embeddings.{name}",
+                output_storage,
+            )
+
+    return WorkflowFunctionOutput(result=None)
diff --git a/graphrag/index/workflows/update_text_units.py b/graphrag/index/workflows/update_text_units.py
new file mode 100644
index 0000000000..4b26b47b07
--- /dev/null
+++ b/graphrag/index/workflows/update_text_units.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+import numpy as np
+import pandas as pd
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.utils import get_update_storages
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
+
+logger = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Update the text units from an incremental index run."""
+    logger.info("Updating Text Units")
+    output_storage, previous_storage, delta_storage = get_update_storages(
+        config, context.state["update_timestamp"]
+    )
+    entity_id_mapping = context.state["incremental_update_entity_id_mapping"]
+
+    merged_text_units = await _update_text_units(
+        previous_storage, delta_storage, output_storage, entity_id_mapping
+    )
+
+    context.state["incremental_update_merged_text_units"] = merged_text_units
+
+    return WorkflowFunctionOutput(result=None)
+
+
+async def _update_text_units(
+    previous_storage: PipelineStorage,
+    delta_storage: PipelineStorage,
+    output_storage: PipelineStorage,
+    entity_id_mapping: dict,
+) -> pd.DataFrame:
+    """Update the text units output."""
+    old_text_units = await load_table_from_storage("text_units", previous_storage)
+    delta_text_units = await load_table_from_storage("text_units", delta_storage)
+    merged_text_units = _update_and_merge_text_units(
+        old_text_units, delta_text_units, entity_id_mapping
+    )
+
+    await write_table_to_storage(merged_text_units, "text_units", output_storage)
+
+    return merged_text_units
+
+
+def _update_and_merge_text_units(
+    old_text_units: pd.DataFrame,
+    delta_text_units: pd.DataFrame,
+    entity_id_mapping: dict,
+) -> pd.DataFrame:
+    """Update and merge text units.
+
+    Parameters
+    ----------
+    old_text_units : pd.DataFrame
+        The old text units.
+    delta_text_units : pd.DataFrame
+        The delta text units.
+    entity_id_mapping : dict
+        The entity id mapping.
+
+    Returns
+    -------
+    pd.DataFrame
+        The updated text units.
+    """
+    # Look for entity ids in entity_ids and replace them with the corresponding id in the mapping
+    if entity_id_mapping:
+        delta_text_units["entity_ids"] = delta_text_units["entity_ids"].apply(
+            lambda x: [entity_id_mapping.get(i, i) for i in x] if x is not None else x
+        )
+
+    initial_id = old_text_units["human_readable_id"].max() + 1
+    delta_text_units["human_readable_id"] = np.arange(
+        initial_id, initial_id + len(delta_text_units)
+    )
+    # Merge the final text units
+    return pd.concat([old_text_units, delta_text_units], ignore_index=True, copy=False)
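For completeness, _update_and_merge_text_units on toy data: delta entity ids that were resolved onto existing entities are rewritten through the mapping produced by update_entities_relationships, then the rows are renumbered and appended exactly as for covariates (all values here are hypothetical):

    import numpy as np
    import pandas as pd

    entity_id_mapping = {"delta-e1": "canonical-e1"}

    old = pd.DataFrame({"entity_ids": [["e1"]], "human_readable_id": [1]})
    delta = pd.DataFrame({"entity_ids": [["delta-e1", "e2"]], "human_readable_id": [1]})

    # rewrite delta entity ids that map onto existing entities
    delta["entity_ids"] = delta["entity_ids"].apply(
        lambda x: [entity_id_mapping.get(i, i) for i in x] if x is not None else x
    )

    # renumber after the largest existing id, then append
    initial_id = old["human_readable_id"].max() + 1
    delta["human_readable_id"] = np.arange(initial_id, initial_id + len(delta))
    merged = pd.concat([old, delta], ignore_index=True, copy=False)

    print(merged["entity_ids"].tolist())  # [['e1'], ['canonical-e1', 'e2']]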