diff --git a/.semversioner/next-release/patch-20260223133523034773.json b/.semversioner/next-release/patch-20260223133523034773.json new file mode 100644 index 0000000000..67ae560b15 --- /dev/null +++ b/.semversioner/next-release/patch-20260223133523034773.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "create_final_documents streaming" +} diff --git a/packages/graphrag/graphrag/index/workflows/create_final_documents.py b/packages/graphrag/graphrag/index/workflows/create_final_documents.py index ccbd967821..7b3f65f991 100644 --- a/packages/graphrag/graphrag/index/workflows/create_final_documents.py +++ b/packages/graphrag/graphrag/index/workflows/create_final_documents.py @@ -1,14 +1,17 @@ -# Copyright (c) 2024 Microsoft Corporation. +# Copyright (C) 2026 Microsoft # Licensed under the MIT License -"""A module containing run_workflow method definition.""" +"""Workflow to create final documents with text unit mappings.""" import logging +from typing import Any -import pandas as pd +from graphrag_storage.tables.table import Table from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.data_model.data_reader import DataReader +from graphrag.data_model.row_transformers import ( + transform_document_row, +) from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -20,49 +23,51 @@ async def run_workflow( _config: GraphRagConfig, context: PipelineRunContext, ) -> WorkflowFunctionOutput: - """All the steps to transform final documents.""" + """Transform final documents via streaming Table reads/writes.""" logger.info("Workflow started: create_final_documents") - reader = DataReader(context.output_table_provider) - documents = await reader.documents() - text_units = await reader.text_units() - output = create_final_documents(documents, text_units) - - await context.output_table_provider.write_dataframe("documents", output) + async with ( + context.output_table_provider.open( + "text_units", + ) as text_units_table, + context.output_table_provider.open( + "documents", + transformer=transform_document_row, + ) as documents_table, + context.output_table_provider.open( + "documents", + ) as output_table, + ): + sample = await create_final_documents( + text_units_table, + documents_table, + output_table, + ) logger.info("Workflow completed: create_final_documents") - return WorkflowFunctionOutput(result=output) - - -def create_final_documents( - documents: pd.DataFrame, text_units: pd.DataFrame -) -> pd.DataFrame: - """All the steps to transform final documents.""" - renamed = text_units.loc[:, ["id", "document_id", "text"]].rename( - columns={ - "document_id": "chunk_doc_id", - "id": "chunk_id", - "text": "chunk_text", - } - ) - - joined = renamed.merge( - documents, - left_on="chunk_doc_id", - right_on="id", - how="inner", - copy=False, - ) - - docs_with_text_units = joined.groupby("id", sort=False).agg( - text_unit_ids=("chunk_id", list) - ) - - rejoined = docs_with_text_units.merge( - documents, - on="id", - how="right", - copy=False, - ).reset_index(drop=True) - - return rejoined.loc[:, DOCUMENTS_FINAL_COLUMNS] + return WorkflowFunctionOutput(result=sample) + + +async def create_final_documents( + text_units_table: Table, + documents_table: Table, + output_table: Table, +) -> list[dict[str, Any]]: + """Build text-unit mapping, then stream-enrich documents.""" + mapping: dict[str, list[str]] = {} + async for row in text_units_table: + document_id = row.get("document_id", "") + if document_id: + mapping.setdefault(document_id, []).append( + row["id"], + ) + + sample_rows: list[dict[str, Any]] = [] + async for row in documents_table: + row["text_unit_ids"] = mapping.get(row["id"], []) + out = {c: row.get(c) for c in DOCUMENTS_FINAL_COLUMNS} + await output_table.write(out) + if len(sample_rows) < 5: + sample_rows.append(out) + + return sample_rows