4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260223133523034773.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "create_final_documents streaming"
}
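The file added above is a semversioner pending-change record: as I understand semversioner's release step, it collects every JSON entry under .semversioner/next-release/ and bumps the version according to the most significant type it finds. A rough sketch of that aggregation follows (illustrative only, not semversioner's implementation; the PRECEDENCE table is my own shorthand).

# Hedged illustration of folding pending change entries into one bump level.
import json
from pathlib import Path

PRECEDENCE = {"patch": 0, "minor": 1, "major": 2}

pending = [
    json.loads(p.read_text())
    for p in Path(".semversioner/next-release").glob("*.json")
]
if pending:
    bump = max(pending, key=lambda change: PRECEDENCE[change["type"]])["type"]
    print(bump)  # with only the entry above present, this prints "patch"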
@@ -1,14 +1,17 @@
# Copyright (c) 2024 Microsoft Corporation.
# Copyright (C) 2026 Microsoft
# Licensed under the MIT License

"""A module containing run_workflow method definition."""
"""Workflow to create final documents with text unit mappings."""

import logging
from typing import Any

import pandas as pd
from graphrag_storage.tables.table import Table

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.data_reader import DataReader
from graphrag.data_model.row_transformers import (
    transform_document_row,
)
from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
@@ -20,49 +23,51 @@ async def run_workflow(
    _config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform final documents."""
    """Transform final documents via streaming Table reads/writes."""
    logger.info("Workflow started: create_final_documents")
    reader = DataReader(context.output_table_provider)
    documents = await reader.documents()
    text_units = await reader.text_units()

    output = create_final_documents(documents, text_units)

    await context.output_table_provider.write_dataframe("documents", output)
    async with (
        context.output_table_provider.open(
            "text_units",
        ) as text_units_table,
        context.output_table_provider.open(
            "documents",
            transformer=transform_document_row,
        ) as documents_table,
        context.output_table_provider.open(
            "documents",
        ) as output_table,
    ):
        sample = await create_final_documents(
            text_units_table,
            documents_table,
            output_table,
        )

    logger.info("Workflow completed: create_final_documents")
    return WorkflowFunctionOutput(result=output)


def create_final_documents(
    documents: pd.DataFrame, text_units: pd.DataFrame
) -> pd.DataFrame:
    """All the steps to transform final documents."""
    renamed = text_units.loc[:, ["id", "document_id", "text"]].rename(
        columns={
            "document_id": "chunk_doc_id",
            "id": "chunk_id",
            "text": "chunk_text",
        }
    )

    joined = renamed.merge(
        documents,
        left_on="chunk_doc_id",
        right_on="id",
        how="inner",
        copy=False,
    )

    docs_with_text_units = joined.groupby("id", sort=False).agg(
        text_unit_ids=("chunk_id", list)
    )

    rejoined = docs_with_text_units.merge(
        documents,
        on="id",
        how="right",
        copy=False,
    ).reset_index(drop=True)

    return rejoined.loc[:, DOCUMENTS_FINAL_COLUMNS]
    return WorkflowFunctionOutput(result=sample)


async def create_final_documents(
    text_units_table: Table,
    documents_table: Table,
    output_table: Table,
) -> list[dict[str, Any]]:
    """Build text-unit mapping, then stream-enrich documents."""
    mapping: dict[str, list[str]] = {}
    async for row in text_units_table:
        document_id = row.get("document_id", "")
        if document_id:
            mapping.setdefault(document_id, []).append(
                row["id"],
            )

    sample_rows: list[dict[str, Any]] = []
    async for row in documents_table:
        row["text_unit_ids"] = mapping.get(row["id"], [])
        out = {c: row.get(c) for c in DOCUMENTS_FINAL_COLUMNS}
        await output_table.write(out)
        if len(sample_rows) < 5:
            sample_rows.append(out)

    return sample_rows
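For context on the new streaming shape, here is a self-contained sketch that mimics the two-pass logic above against a hypothetical in-memory stand-in for the graphrag_storage Table (async iteration for reads, awaitable write() for output — the only surface the diff relies on). It is an illustration of the pattern, not the real provider or its API.

# Minimal sketch, assuming only the Table surface visible in the diff.
import asyncio
from typing import Any


class InMemoryTable:
    """Hypothetical stand-in for a Table: async-iterable rows plus an async write()."""

    def __init__(self, rows: list[dict[str, Any]] | None = None) -> None:
        self.rows = rows or []

    def __aiter__(self):
        self._iter = iter(self.rows)
        return self

    async def __anext__(self) -> dict[str, Any]:
        try:
            return next(self._iter)
        except StopIteration:
            raise StopAsyncIteration from None

    async def write(self, row: dict[str, Any]) -> None:
        self.rows.append(row)


async def main() -> None:
    text_units = InMemoryTable([
        {"id": "tu-1", "document_id": "doc-1", "text": "chunk one"},
        {"id": "tu-2", "document_id": "doc-1", "text": "chunk two"},
    ])
    documents = InMemoryTable([{"id": "doc-1", "title": "report.txt"}])
    output = InMemoryTable()

    # First pass: text-unit ids grouped by document, as in the new workflow code.
    mapping: dict[str, list[str]] = {}
    async for row in text_units:
        mapping.setdefault(row["document_id"], []).append(row["id"])

    # Second pass: enrich each document row and stream it to the output table.
    async for row in documents:
        row["text_unit_ids"] = mapping.get(row["id"], [])
        await output.write(row)

    print(output.rows)
    # [{'id': 'doc-1', 'title': 'report.txt', 'text_unit_ids': ['tu-1', 'tu-2']}]


asyncio.run(main())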