From 5a04836e0fcb34b8fa299e5b0479a990c45c0811 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Tue, 10 Feb 2026 21:22:22 +0000 Subject: [PATCH 01/10] add streaming read --- .../graphrag_storage/file_storage.py | 4 + .../graphrag_storage/tables/__init__.py | 3 +- .../graphrag_storage/tables/csv_table.py | 110 +++++++ .../tables/csv_table_provider.py | 15 +- .../tables/parquet_table_provider.py | 18 ++ .../graphrag_storage/tables/table.py | 116 +++++++ .../graphrag_storage/tables/table_provider.py | 21 ++ .../graphrag/index/workflows/__init__.py | 3 + .../index/workflows/create_base_text_units.py | 173 +++++++---- .../graphrag/prompt_tune/loader/input.py | 17 +- tests/unit/prompt_tune/__init__.py | 4 + .../prompt_tune/test_load_docs_in_chunks.py | 292 ++++++++++++++++++ 12 files changed, 700 insertions(+), 76 deletions(-) create mode 100644 packages/graphrag-storage/graphrag_storage/tables/csv_table.py create mode 100644 packages/graphrag-storage/graphrag_storage/tables/table.py create mode 100644 tests/unit/prompt_tune/__init__.py create mode 100644 tests/unit/prompt_tune/test_load_docs_in_chunks.py diff --git a/packages/graphrag-storage/graphrag_storage/file_storage.py b/packages/graphrag-storage/graphrag_storage/file_storage.py index 547659abcd..7eb89dcc2e 100644 --- a/packages/graphrag-storage/graphrag_storage/file_storage.py +++ b/packages/graphrag-storage/graphrag_storage/file_storage.py @@ -144,6 +144,10 @@ async def get_creation_date(self, key: str) -> str: return get_timestamp_formatted_with_local_tz(creation_time_utc) + def get_path(self, key: str) -> Path: + """Get the full file path for a key (for streaming access).""" + return _join_path(self._base_dir, key) + def _join_path(file_path: Path, file_name: str) -> Path: """Join a path and a file. 
Independent of the OS.""" diff --git a/packages/graphrag-storage/graphrag_storage/tables/__init__.py b/packages/graphrag-storage/graphrag_storage/tables/__init__.py index 0210d935f3..9f95b076ca 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/__init__.py +++ b/packages/graphrag-storage/graphrag_storage/tables/__init__.py @@ -3,6 +3,7 @@ """Table provider module for GraphRAG storage.""" +from .table import Table from .table_provider import TableProvider -__all__ = ["TableProvider"] +__all__ = ["Table", "TableProvider"] diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py new file mode 100644 index 0000000000..b643674d28 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT Licenses + +"""A CSV-based implementation of the Table abstraction for streaming row access.""" + +import csv +import inspect +from collections.abc import AsyncIterator +from pathlib import Path +from typing import Any + +import aiofiles + +from graphrag_storage import Storage +from graphrag_storage.file_storage import FileStorage +from graphrag_storage.tables.table import RowTransformer, Table + + +def _identity(row: dict[str, Any]) -> Any: + """Return row unchanged (default transformer).""" + return row + + +def _apply_transformer(transformer: RowTransformer, row: dict[str, Any]) -> Any: + """Apply transformer to row, handling both callables and classes. + + If transformer is a class (e.g., Pydantic model), calls it with **row. + Otherwise calls it with row as positional argument. 
+ """ + if inspect.isclass(transformer): + return transformer(**row) + return transformer(row) + + +class CSVTable(Table): + """Row-by-row streaming interface for CSV tables.""" + + def __init__( + self, + storage: Storage, + table_name: str, + transformer: RowTransformer | None = None, + ): + """Initialize with storage backend and table name. + + Args: + storage: Storage instance (File, Blob, or Cosmos) + table_name: Name of the table (e.g., "documents") + transformer: Optional callable to transform each row before + yielding. Receives a dict, returns a transformed dict. + Defaults to identity (no transformation). + """ + self._storage = storage + self._table_name = table_name + self._file_key = f"{table_name}.csv" + self._transformer = transformer or _identity + + def __aiter__(self) -> AsyncIterator[Any]: + """Iterate through rows one at a time. + + The transformer is applied to each row before yielding. + If transformer is a Pydantic model, yields model instances. + + Yields + ------ + Any: + Each row as dict or transformed type (e.g., Pydantic model). 
+ """ + return self._aiter_impl() + + async def _aiter_impl(self) -> AsyncIterator[Any]: + """Implement async iteration over rows.""" + if isinstance(self._storage, FileStorage): + file_path = self._storage.get_path(self._file_key) + with Path.open(file_path, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + yield _apply_transformer(self._transformer, row) + + async def length(self) -> int: + """Return the number of rows in the table.""" + if isinstance(self._storage, FileStorage): + file_path = self._storage.get_path(self._file_key) + count = 0 + async with aiofiles.open(file_path, "rb") as f: + while True: + chunk = await f.read(65536) + if not chunk: + break + count += chunk.count(b"\n") + return count - 1 + return 0 + + async def has(self, row_id: str) -> bool: + """Check if row with given ID exists.""" + async for row in self: + # Handle both dict and object (e.g., Pydantic model) + if isinstance(row, dict): + if row.get("id") == row_id: + return True + elif getattr(row, "id", None) == row_id: + return True + return False + + async def close(self) -> None: + """Flush buffered writes and release resources. + + No-op for CSV tables since rows are read on demand + and no persistent connections are held open. + """ diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py index 5de021b8a5..9c116fab09 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Microsoft Corporation. +# Copyright (c) 2025 Microsoft Corporation. 
# Licensed under the MIT License """CSV-based table provider implementation.""" @@ -9,12 +9,14 @@ import pandas as pd +from graphrag_storage.file_storage import FileStorage from graphrag_storage.storage import Storage +from graphrag_storage.tables.csv_table import CSVTable +from graphrag_storage.tables.table import RowTransformer from graphrag_storage.tables.table_provider import TableProvider logger = logging.getLogger(__name__) - class CSVTableProvider(TableProvider): """Table provider that stores tables as CSV files using an underlying Storage instance. @@ -32,6 +34,9 @@ def __init__(self, storage: Storage, **kwargs) -> None: **kwargs: Any Additional keyword arguments (currently unused). """ + if not isinstance(storage, FileStorage): + msg = "CSVTableProvider only works with FileStorage backends for now. " + raise TypeError(msg) self._storage = storage async def read_dataframe(self, table_name: str) -> pd.DataFrame: @@ -108,3 +113,9 @@ def list(self) -> list[str]: file.replace(".csv", "") for file in self._storage.find(re.compile(r"\.csv$")) ] + + def open( + self, table_name: str, transformer: RowTransformer | None = None + ) -> CSVTable: + """Open table for streaming.""" + return CSVTable(self._storage, table_name, transformer=transformer) diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py index 74f63660dc..add5599839 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py @@ -10,6 +10,7 @@ import pandas as pd from graphrag_storage.storage import Storage +from graphrag_storage.tables.table import RowTransformer, Table from graphrag_storage.tables.table_provider import TableProvider logger = logging.getLogger(__name__) @@ -106,3 +107,20 @@ def list(self) -> list[str]: file.replace(".parquet", "") for file in 
self._storage.find(re.compile(r"\.parquet$")) ] + + def open(self, table_name: str, transformer: RowTransformer | None = None) -> Table: + """Open a table for streaming row operations. + + Not yet implemented for Parquet tables. Parquet format requires + loading data in chunks rather than true row-by-row streaming. + + Raises + ------ + NotImplementedError: + Streaming access is not yet supported for Parquet tables. + """ + msg = ( + "Streaming access not yet implemented for ParquetTableProvider. " + "Use read_dataframe() for bulk access." + ) + raise NotImplementedError(msg) diff --git a/packages/graphrag-storage/graphrag_storage/tables/table.py b/packages/graphrag-storage/graphrag_storage/tables/table.py new file mode 100644 index 0000000000..ca881728e1 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/table.py @@ -0,0 +1,116 @@ +# Copyright (C) 2025 Microsoft +# Licensed under the MIT License + +"""Table abstraction for streaming row-by-row access.""" + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Callable +from types import TracebackType +from typing import Any + +from typing_extensions import Self + +RowTransformer = Callable[[dict[str, Any]], Any] + + +class Table(ABC): + """Abstract base class for streaming table access. + + Provides row-by-row iteration and write capabilities for memory-efficient + processing of large datasets. Supports async context manager protocol for + automatic resource cleanup. + + Examples + -------- + Reading rows as dicts: + >>> async with ( + ... provider.open( + ... "documents" + ... ) as table + ... ): + ... async for ( + ... row + ... ) in table: + ... process(row) + + With Pydantic model as transformer: + >>> async with ( + ... provider.open( + ... "entities", + ... Entity, + ... ) as table + ... ): + ... async for entity in table: # yields Entity instances + ... print( + ... entity.name + ... 
) + """ + + @abstractmethod + def __aiter__(self) -> AsyncIterator[Any]: + """Yield rows asynchronously, transformed if transformer provided. + + Yields + ------ + Any: + Each row, either as dict or transformed type (e.g., Pydantic model). + """ + ... + + @abstractmethod + async def length(self) -> int: + """Return number of rows asynchronously. + + Returns + ------- + int: + Number of rows in the table. + """ + + @abstractmethod + async def has(self, row_id: str) -> bool: + """Check if a row with the given ID exists. + + Args + ---- + row_id: The ID value to search for. + + Returns + ------- + bool: + True if a row with matching ID exists. + """ + + @abstractmethod + async def close(self) -> None: + """Flush buffered writes and release resources. + + This method is called automatically when exiting the async context + manager, but can also be called explicitly. + """ + + async def __aenter__(self) -> Self: + """Enter async context manager. + + Returns + ------- + Table: + Self for context manager usage. + """ + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit async context manager, ensuring close() is called. 
+ + Args + ---- + exc_type: Exception type if an exception occurred + exc_val: Exception value if an exception occurred + exc_tb: Exception traceback if an exception occurred + """ + await self.close() diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py index 07a86c3119..74bb049937 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py @@ -8,6 +8,8 @@ import pandas as pd +from graphrag_storage.tables.table import RowTransformer, Table + class TableProvider(ABC): """Provide a table-based storage interface with support for DataFrames and row dictionaries.""" @@ -73,3 +75,22 @@ def list(self) -> list[str]: list[str]: List of table names (without file extensions). """ + + @abstractmethod + def open( + self, table_name: str, transformer: RowTransformer | None = None + ) -> Table: # Returns Table instance + """Open a table for row-by-row streaming operations. + + Args + ---- + table_name: str + The name of the table to open. + transformer: RowTransformer | None + Optional transformer function to apply to each row. + + Returns + ------- + Table: + A Table instance for streaming row operations. 
+ """ diff --git a/packages/graphrag/graphrag/index/workflows/__init__.py b/packages/graphrag/graphrag/index/workflows/__init__.py index 6dee90c097..354a936374 100644 --- a/packages/graphrag/graphrag/index/workflows/__init__.py +++ b/packages/graphrag/graphrag/index/workflows/__init__.py @@ -6,6 +6,9 @@ from graphrag.index.workflows.factory import PipelineFactory +from .create_base_text_units import ( + chunk_document, +) from .create_base_text_units import ( run_workflow as run_create_base_text_units, ) diff --git a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py index 196ab3f1b6..8db04df1c5 100644 --- a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py @@ -4,7 +4,7 @@ """A module containing run_workflow method definition.""" import logging -from typing import Any, cast +from typing import Any import pandas as pd from graphrag_chunking.chunker import Chunker @@ -12,10 +12,10 @@ from graphrag_chunking.transformers import add_metadata from graphrag_input import TextDocument from graphrag_llm.tokenizer import Tokenizer +from graphrag_storage.tables.table import Table from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.data_model.data_reader import DataReader from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.utils.hashing import gen_sha512_hash @@ -31,18 +31,20 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform base text_units.""" logger.info("Workflow started: create_base_text_units") - reader = DataReader(context.output_table_provider) - documents = await reader.documents() tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model) chunker = 
create_chunker(config.chunking, tokenizer.encode, tokenizer.decode) - output = create_base_text_units( - documents, - context.callbacks, - tokenizer=tokenizer, - chunker=chunker, - prepend_metadata=config.chunking.prepend_metadata, - ) + + async with context.output_table_provider.open("documents") as documents_table: + total_rows = await documents_table.length() + output = await create_base_text_units( + documents_table, + total_rows, + context.callbacks, + tokenizer=tokenizer, + chunker=chunker, + prepend_metadata=config.chunking.prepend_metadata, + ) await context.output_table_provider.write_dataframe("text_units", output) @@ -50,70 +52,109 @@ async def run_workflow( return WorkflowFunctionOutput(result=output) -def create_base_text_units( - documents: pd.DataFrame, +async def create_base_text_units( + documents_table: Table, + total_rows: int, callbacks: WorkflowCallbacks, tokenizer: Tokenizer, chunker: Chunker, prepend_metadata: list[str] | None = None, ) -> pd.DataFrame: - """All the steps to transform base text_units.""" - documents.sort_values(by=["id"], ascending=[True], inplace=True) - - total_rows = len(documents) + """Transform documents into chunked text units via streaming read. + + Reads documents row-by-row from an async iterable to avoid loading the + entire documents table into memory at once. Chunks each document + individually and collects the resulting text units into a DataFrame. + + Args + ---- + documents_table: Table + Table instance representing the documents. Supports async iteration over document rows. + total_rows: int + Total number of documents for progress reporting. + callbacks: WorkflowCallbacks + Callbacks for progress reporting. + tokenizer: Tokenizer + Tokenizer for measuring chunk token counts. + chunker: Chunker + Chunker instance for splitting document text. + prepend_metadata: list[str] | None + Optional list of metadata fields to prepend to + each chunk. 
+ + Returns + ------- + pd.DataFrame: + DataFrame with columns: id, document_id, text, + n_tokens. + """ tick = progress_ticker(callbacks.progress, total_rows) - # Track progress of row-wise apply operation - logger.info("Starting chunking process for %d documents", total_rows) - - def chunker_with_logging(row: pd.Series, row_index: int) -> Any: - if prepend_metadata: - # create a standard text document for metadata plucking - # ignore any additional fields in case the input dataframe has extra columns - document = TextDocument( - id=row["id"], - title=row["title"], - text=row["text"], - creation_date=row["creation_date"], - raw_data=row["raw_data"], - ) - metadata = document.collect(prepend_metadata) - transformer = add_metadata( - metadata=metadata, line_delimiter=".\n" - ) # delim with . for back-compat older indexes - else: - transformer = None - - row["chunks"] = [ - chunk.text for chunk in chunker.chunk(row["text"], transform=transformer) - ] + logger.info( + "Starting chunking process for %d documents", + total_rows, + ) + rows: list[dict[str, Any]] = [] + doc_index = 0 + + async for doc in documents_table: + chunks = chunk_document(doc, chunker, prepend_metadata) + for chunk_text in chunks: + if chunk_text is None: + continue + row = { + "document_id": doc["id"], + "text": chunk_text, + } + row["id"] = gen_sha512_hash(row, ["text"]) + row["n_tokens"] = len(tokenizer.encode(chunk_text)) + rows.append(row) + + doc_index += 1 tick() - logger.info("chunker progress: %d/%d", row_index + 1, total_rows) - return row + logger.info( + "chunker progress: %d/%d", + doc_index, + total_rows, + ) + # //write - text_units = documents.apply( - lambda row: chunker_with_logging(row, row.name), axis=1 - ) + df = pd.DataFrame(rows, columns=["id", "document_id", "text", "n_tokens"]) + return df.sort_values(by=["document_id", "id"]).reset_index(drop=True) - text_units = cast("pd.DataFrame", text_units[["id", "chunks"]]) - text_units = text_units.explode("chunks") - 
text_units.rename( - columns={ - "id": "document_id", - "chunks": "text", - }, - inplace=True, - ) - text_units["id"] = text_units.apply( - lambda row: gen_sha512_hash(row, ["text"]), axis=1 - ) - # get a final token measurement - text_units["n_tokens"] = text_units["text"].apply( - lambda x: len(tokenizer.encode(x)) - ) - - return cast( - "pd.DataFrame", text_units[text_units["text"].notna()].reset_index(drop=True) - ) +def chunk_document( + doc: dict[str, Any], + chunker: Chunker, + prepend_metadata: list[str] | None = None, +) -> list[str]: + """Chunk a single document row into text fragments. + + Args + ---- + doc: dict[str, Any] + A single document row as a dictionary. + chunker: Chunker + Chunker instance for splitting text. + prepend_metadata: list[str] | None + Optional metadata fields to prepend. + + Returns + ------- + list[str]: + List of chunk text strings. + """ + transformer = None + if prepend_metadata: + document = TextDocument( + id=doc["id"], + title=doc.get("title", ""), + text=doc["text"], + creation_date=doc.get("creation_date", ""), + raw_data=doc.get("raw_data"), + ) + metadata = document.collect(prepend_metadata) + transformer = add_metadata(metadata=metadata, line_delimiter=".\n") + + return [chunk.text for chunk in chunker.chunk(doc["text"], transform=transformer)] diff --git a/packages/graphrag/graphrag/prompt_tune/loader/input.py b/packages/graphrag/graphrag/prompt_tune/loader/input.py index 0cfdb2299a..fb3be66744 100644 --- a/packages/graphrag/graphrag/prompt_tune/loader/input.py +++ b/packages/graphrag/graphrag/prompt_tune/loader/input.py @@ -3,6 +3,7 @@ """Input loading module.""" +import dataclasses import logging from typing import Any @@ -18,7 +19,7 @@ from graphrag.index.operations.embed_text.run_embed_text import ( run_embed_text, ) -from graphrag.index.workflows.create_base_text_units import create_base_text_units +from graphrag.index.workflows.create_base_text_units import chunk_document from graphrag.prompt_tune.defaults 
import ( LIMIT, N_SUBSET_MAX, @@ -58,12 +59,14 @@ async def load_docs_in_chunks( input_storage = create_storage(config.input_storage) input_reader = create_input_reader(config.input, input_storage) dataset = await input_reader.read_files() - chunks_df = create_base_text_units( - documents=pd.DataFrame(dataset), - callbacks=NoopWorkflowCallbacks(), - tokenizer=tokenizer, - chunker=chunker, - ) + + all_chunks: list[str] = [] + for doc in dataset: + doc_dict = dataclasses.asdict(doc) + chunks = chunk_document(doc_dict, chunker) + all_chunks.extend(chunks) + + chunks_df = pd.DataFrame({"text": all_chunks}) # Depending on the select method, build the dataset if limit <= 0 or limit > len(chunks_df): diff --git a/tests/unit/prompt_tune/__init__.py b/tests/unit/prompt_tune/__init__.py new file mode 100644 index 0000000000..4d4df03613 --- /dev/null +++ b/tests/unit/prompt_tune/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Microsoft +# Licensed under the MIT License + +"""Unit tests for prompt_tune module.""" diff --git a/tests/unit/prompt_tune/test_load_docs_in_chunks.py b/tests/unit/prompt_tune/test_load_docs_in_chunks.py new file mode 100644 index 0000000000..6268aef91e --- /dev/null +++ b/tests/unit/prompt_tune/test_load_docs_in_chunks.py @@ -0,0 +1,292 @@ +# Copyright (C) 2025 Microsoft +# Licensed under the MIT License + +"""Unit tests for load_docs_in_chunks function.""" + +import logging +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from graphrag.prompt_tune.loader.input import load_docs_in_chunks +from graphrag.prompt_tune.types import DocSelectionType + + +@dataclass +class MockTextDocument: + """Mock TextDocument for testing.""" + + id: str + text: str + title: str + creation_date: str + raw_data: dict[str, Any] | None = None + + +class MockTokenizer: + """Mock tokenizer for testing.""" + + def encode(self, text: str) -> list[int]: + """Encode text to tokens (simple 
char-based).""" + return [ord(c) for c in text] + + def decode(self, tokens: list[int]) -> str: + """Decode tokens to text.""" + return "".join(chr(t) for t in tokens) + + +@dataclass +class MockChunk: + """Mock chunk result.""" + + text: str + + +class MockChunker: + """Mock chunker for testing.""" + + def chunk(self, text: str, transform: Any = None) -> list[MockChunk]: + """Split text into sentence-like chunks.""" + sentences = [s.strip() for s in text.split(".") if s.strip()] + return [MockChunk(text=s + ".") for s in sentences] + + +class MockEmbeddingModel: + """Mock embedding model for testing.""" + + def __init__(self): + """Initialize with mock tokenizer.""" + self.tokenizer = MockTokenizer() + + +@pytest.fixture +def mock_config(): + """Create a mock GraphRagConfig.""" + config = MagicMock() + config.embed_text.embedding_model_id = "test-model" + config.embed_text.batch_size = 10 + config.embed_text.batch_max_tokens = 1000 + config.concurrent_requests = 1 + config.get_embedding_model_config.return_value = MagicMock() + return config + + +@pytest.fixture +def mock_logger(): + """Create a mock logger.""" + return logging.getLogger("test") + + +@pytest.fixture +def sample_documents(): + """Create sample documents for testing.""" + return [ + MockTextDocument( + id="doc1", + text="First sentence. Second sentence. Third sentence.", + title="Doc 1", + creation_date="2025-01-01", + ), + MockTextDocument( + id="doc2", + text="Another document. 
With content.", + title="Doc 2", + creation_date="2025-01-02", + ), + ] + + +class TestLoadDocsInChunks: + """Tests for load_docs_in_chunks function.""" + + @pytest.mark.asyncio + async def test_top_selection_returns_limited_chunks( + self, mock_config, mock_logger, sample_documents + ): + """Test TOP selection method returns the first N chunks.""" + mock_reader = AsyncMock() + mock_reader.read_files.return_value = sample_documents + + with ( + patch( + "graphrag.prompt_tune.loader.input.create_embedding", + return_value=MockEmbeddingModel(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_storage", + return_value=MagicMock(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_input_reader", + return_value=mock_reader, + ), + patch( + "graphrag.prompt_tune.loader.input.create_chunker", + return_value=MockChunker(), + ), + ): + result = await load_docs_in_chunks( + config=mock_config, + select_method=DocSelectionType.TOP, + limit=2, + logger=mock_logger, + ) + + assert len(result) == 2 + assert result[0] == "First sentence." + assert result[1] == "Second sentence." 
+ + @pytest.mark.asyncio + async def test_random_selection_returns_correct_count( + self, mock_config, mock_logger, sample_documents + ): + """Test RANDOM selection method returns the correct number of chunks.""" + mock_reader = AsyncMock() + mock_reader.read_files.return_value = sample_documents + + with ( + patch( + "graphrag.prompt_tune.loader.input.create_embedding", + return_value=MockEmbeddingModel(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_storage", + return_value=MagicMock(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_input_reader", + return_value=mock_reader, + ), + patch( + "graphrag.prompt_tune.loader.input.create_chunker", + return_value=MockChunker(), + ), + ): + result = await load_docs_in_chunks( + config=mock_config, + select_method=DocSelectionType.RANDOM, + limit=3, + logger=mock_logger, + ) + + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_escapes_braces_in_output(self, mock_config, mock_logger): + """Test that curly braces are escaped for str.format() compatibility.""" + docs_with_braces = [ + MockTextDocument( + id="doc1", + text="Some {latex} content.", + title="Doc 1", + creation_date="2025-01-01", + ), + ] + + mock_reader = AsyncMock() + mock_reader.read_files.return_value = docs_with_braces + + with ( + patch( + "graphrag.prompt_tune.loader.input.create_embedding", + return_value=MockEmbeddingModel(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_storage", + return_value=MagicMock(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_input_reader", + return_value=mock_reader, + ), + patch( + "graphrag.prompt_tune.loader.input.create_chunker", + return_value=MockChunker(), + ), + ): + result = await load_docs_in_chunks( + config=mock_config, + select_method=DocSelectionType.TOP, + limit=1, + logger=mock_logger, + ) + + assert len(result) == 1 + assert "{{latex}}" in result[0] + + @pytest.mark.asyncio + async def test_limit_out_of_range_uses_default( + self, mock_config, 
mock_logger, sample_documents + ): + """Test that invalid limit falls back to default LIMIT.""" + mock_reader = AsyncMock() + mock_reader.read_files.return_value = sample_documents + + with ( + patch( + "graphrag.prompt_tune.loader.input.create_embedding", + return_value=MockEmbeddingModel(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_storage", + return_value=MagicMock(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_input_reader", + return_value=mock_reader, + ), + patch( + "graphrag.prompt_tune.loader.input.create_chunker", + return_value=MockChunker(), + ), + patch( + "graphrag.prompt_tune.loader.input.LIMIT", + 3, + ), + ): + result = await load_docs_in_chunks( + config=mock_config, + select_method=DocSelectionType.TOP, + limit=-1, + logger=mock_logger, + ) + + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_chunks_all_documents( + self, mock_config, mock_logger, sample_documents + ): + """Test that all documents are chunked correctly.""" + mock_reader = AsyncMock() + mock_reader.read_files.return_value = sample_documents + + with ( + patch( + "graphrag.prompt_tune.loader.input.create_embedding", + return_value=MockEmbeddingModel(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_storage", + return_value=MagicMock(), + ), + patch( + "graphrag.prompt_tune.loader.input.create_input_reader", + return_value=mock_reader, + ), + patch( + "graphrag.prompt_tune.loader.input.create_chunker", + return_value=MockChunker(), + ), + ): + result = await load_docs_in_chunks( + config=mock_config, + select_method=DocSelectionType.TOP, + limit=5, + logger=mock_logger, + ) + + assert len(result) == 5 + assert "First sentence." in result + assert "Another document." 
in result From b0432a70c83c0768dac479dbac1f20720b372dc7 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Tue, 10 Feb 2026 21:28:00 +0000 Subject: [PATCH 02/10] fix formatting --- .../graphrag_storage/tables/csv_table_provider.py | 1 + packages/graphrag/graphrag/index/workflows/__init__.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py index 9c116fab09..a8d97df2b3 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) + class CSVTableProvider(TableProvider): """Table provider that stores tables as CSV files using an underlying Storage instance. diff --git a/packages/graphrag/graphrag/index/workflows/__init__.py b/packages/graphrag/graphrag/index/workflows/__init__.py index 354a936374..6dee90c097 100644 --- a/packages/graphrag/graphrag/index/workflows/__init__.py +++ b/packages/graphrag/graphrag/index/workflows/__init__.py @@ -6,9 +6,6 @@ from graphrag.index.workflows.factory import PipelineFactory -from .create_base_text_units import ( - chunk_document, -) from .create_base_text_units import ( run_workflow as run_create_base_text_units, ) From 76a80ce986e565aca83ec3bc4601280fba7eb779 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Tue, 10 Feb 2026 23:46:19 +0000 Subject: [PATCH 03/10] create csv streaming write --- .../graphrag_storage/tables/csv_table.py | 44 +++++++++++++++--- .../graphrag_storage/tables/table.py | 9 ++++ .../index/workflows/create_base_text_units.py | 45 ++++++++----------- 3 files changed, 66 insertions(+), 32 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py index b643674d28..2630ac8ad7 100644 --- 
a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py @@ -3,18 +3,24 @@ """A CSV-based implementation of the Table abstraction for streaming row access.""" +from __future__ import annotations + import csv import inspect -from collections.abc import AsyncIterator from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import aiofiles -from graphrag_storage import Storage from graphrag_storage.file_storage import FileStorage from graphrag_storage.tables.table import RowTransformer, Table +if TYPE_CHECKING: + from collections.abc import AsyncIterator + from io import TextIOWrapper + + from graphrag_storage import Storage + def _identity(row: dict[str, Any]) -> Any: """Return row unchanged (default transformer).""" @@ -54,6 +60,9 @@ def __init__( self._table_name = table_name self._file_key = f"{table_name}.csv" self._transformer = transformer or _identity + self._write_file: TextIOWrapper | None = None + self._writer: csv.DictWriter | None = None + self._header_written = False def __aiter__(self) -> AsyncIterator[Any]: """Iterate through rows one at a time. @@ -102,9 +111,34 @@ async def has(self, row_id: str) -> bool: return True return False + async def write(self, row: dict[str, Any]) -> None: + """Write a single row to the CSV file. + + On first write, opens the file and writes the header row. + Subsequent writes append rows to the file. + + Args + ---- + row: Dictionary representing a single row to write. 
+ """ + if isinstance(self._storage, FileStorage) and self._write_file is None: + file_path = self._storage.get_path(self._file_key) + file_path.parent.mkdir(parents=True, exist_ok=True) + self._write_file = Path.open(file_path, "w", encoding="utf-8", newline="") + self._writer = csv.DictWriter(self._write_file, fieldnames=list(row.keys())) + self._writer.writeheader() + self._header_written = True + + if self._writer is not None: + self._writer.writerow(row) + async def close(self) -> None: """Flush buffered writes and release resources. - No-op for CSV tables since rows are read on demand - and no persistent connections are held open. + Closes the file handle if writing was performed. """ + if self._write_file is not None: + self._write_file.close() + self._write_file = None + self._writer = None + self._header_written = False diff --git a/packages/graphrag-storage/graphrag_storage/tables/table.py b/packages/graphrag-storage/graphrag_storage/tables/table.py index ca881728e1..d845fb8b3f 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/table.py +++ b/packages/graphrag-storage/graphrag_storage/tables/table.py @@ -81,6 +81,15 @@ async def has(self, row_id: str) -> bool: True if a row with matching ID exists. """ + @abstractmethod + async def write(self, row: dict[str, Any]) -> None: + """Write a single row to the table. + + Args + ---- + row: Dictionary representing a single row to write. + """ + @abstractmethod async def close(self) -> None: """Flush buffered writes and release resources. 
diff --git a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py index 8db04df1c5..1aa012bb78 100644 --- a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py @@ -6,7 +6,6 @@ import logging from typing import Any -import pandas as pd from graphrag_chunking.chunker import Chunker from graphrag_chunking.chunker_factory import create_chunker from graphrag_chunking.transformers import add_metadata @@ -35,10 +34,14 @@ async def run_workflow( tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model) chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode) - async with context.output_table_provider.open("documents") as documents_table: + async with ( + context.output_table_provider.open("documents") as documents_table, + context.output_table_provider.open("text_units") as text_units_table, + ): total_rows = await documents_table.length() - output = await create_base_text_units( + await create_base_text_units( documents_table, + text_units_table, total_rows, context.callbacks, tokenizer=tokenizer, @@ -46,30 +49,30 @@ async def run_workflow( prepend_metadata=config.chunking.prepend_metadata, ) - await context.output_table_provider.write_dataframe("text_units", output) - logger.info("Workflow completed: create_base_text_units") - return WorkflowFunctionOutput(result=output) + return WorkflowFunctionOutput(result=None) async def create_base_text_units( documents_table: Table, + text_units_table: Table, total_rows: int, callbacks: WorkflowCallbacks, tokenizer: Tokenizer, chunker: Chunker, prepend_metadata: list[str] | None = None, -) -> pd.DataFrame: - """Transform documents into chunked text units via streaming read. +) -> None: + """Transform documents into chunked text units via streaming read/write. 
- Reads documents row-by-row from an async iterable to avoid loading the - entire documents table into memory at once. Chunks each document - individually and collects the resulting text units into a DataFrame. + Reads documents row-by-row from an async iterable and writes text units + directly to the output table, avoiding loading all data into memory. Args ---- documents_table: Table - Table instance representing the documents. Supports async iteration over document rows. + Table instance for reading documents. Supports async iteration. + text_units_table: Table + Table instance for writing text units row by row. total_rows: int Total number of documents for progress reporting. callbacks: WorkflowCallbacks @@ -81,12 +84,6 @@ async def create_base_text_units( prepend_metadata: list[str] | None Optional list of metadata fields to prepend to each chunk. - - Returns - ------- - pd.DataFrame: - DataFrame with columns: id, document_id, text, - n_tokens. """ tick = progress_ticker(callbacks.progress, total_rows) @@ -95,7 +92,6 @@ async def create_base_text_units( total_rows, ) - rows: list[dict[str, Any]] = [] doc_index = 0 async for doc in documents_table: @@ -104,12 +100,13 @@ async def create_base_text_units( if chunk_text is None: continue row = { + "id": "", "document_id": doc["id"], "text": chunk_text, + "n_tokens": len(tokenizer.encode(chunk_text)), } row["id"] = gen_sha512_hash(row, ["text"]) - row["n_tokens"] = len(tokenizer.encode(chunk_text)) - rows.append(row) + await text_units_table.write(row) doc_index += 1 tick() @@ -118,12 +115,6 @@ async def create_base_text_units( doc_index, total_rows, ) - # //write - - df = pd.DataFrame(rows, columns=["id", "document_id", "text", "n_tokens"]) - return df.sort_values(by=["document_id", "id"]).reset_index(drop=True) - - def chunk_document( doc: dict[str, Any], chunker: Chunker, From b9c5b51a5dff435d1b54ca5c3c0523af532dd1ca Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Tue, 10 Feb 2026 23:49:56 +0000 Subject: 
[PATCH 04/10] add parquettable to simulate streaming --- .../graphrag_storage/tables/parquet_table.py | 141 ++++++++++++++++++ .../tables/parquet_table_provider.py | 26 ++-- .../index/workflows/create_base_text_units.py | 2 + 3 files changed, 158 insertions(+), 11 deletions(-) create mode 100644 packages/graphrag-storage/graphrag_storage/tables/parquet_table.py diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py new file mode 100644 index 0000000000..b743db4296 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py @@ -0,0 +1,141 @@ +# Copyright (C) 2025 Microsoft +# Licensed under the MIT License + +"""A Parquet-based implementation of the Table abstraction with simulated streaming.""" + +from __future__ import annotations + +import inspect +from io import BytesIO +from typing import TYPE_CHECKING, Any, cast + +import pandas as pd + +from graphrag_storage.tables.table import RowTransformer, Table + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from graphrag_storage.storage import Storage + + +def _identity(row: dict[str, Any]) -> Any: + """Return row unchanged (default transformer).""" + return row + + +def _apply_transformer(transformer: RowTransformer, row: dict[str, Any]) -> Any: + """Apply transformer to row, handling both callables and classes. + + If transformer is a class (e.g., Pydantic model), calls it with **row. + Otherwise calls it with row as positional argument. + """ + if inspect.isclass(transformer): + return transformer(**row) + return transformer(row) + + +class ParquetTable(Table): + """Simulated streaming interface for Parquet tables. 
+ + Parquet format doesn't support true row-by-row streaming, so this + implementation simulates streaming via: + - Read: Loads DataFrame, yields rows via iterrows() + - Write: Accumulates rows in memory, writes all at once on close() + + This provides API compatibility with CSVTable while maintaining + Parquet's performance characteristics for bulk operations. + """ + + def __init__( + self, + storage: Storage, + table_name: str, + transformer: RowTransformer | None = None, + ): + """Initialize with storage backend and table name. + + Args: + storage: Storage instance (File, Blob, or Cosmos) + table_name: Name of the table (e.g., "documents") + transformer: Optional callable to transform each row before + yielding. Receives a dict, returns a transformed dict. + Defaults to identity (no transformation). + """ + self._storage = storage + self._table_name = table_name + self._file_key = f"{table_name}.parquet" + self._transformer = transformer or _identity + self._df: pd.DataFrame | None = None + self._write_rows: list[dict[str, Any]] = [] + + def __aiter__(self) -> AsyncIterator[Any]: + """Iterate through rows one at a time. + + Loads the entire DataFrame on first iteration, then yields rows + one at a time with the transformer applied. + + Yields + ------ + Any: + Each row as dict or transformed type (e.g., Pydantic model). 
+ """ + return self._aiter_impl() + + async def _aiter_impl(self) -> AsyncIterator[Any]: + """Implement async iteration over rows.""" + if self._df is None: + if await self._storage.has(self._file_key): + data = await self._storage.get(self._file_key, as_bytes=True) + self._df = pd.read_parquet(BytesIO(data)) + else: + self._df = pd.DataFrame() + + for _, row in self._df.iterrows(): + row_dict = cast("dict[str, Any]", row.to_dict()) + yield _apply_transformer(self._transformer, row_dict) + + async def length(self) -> int: + """Return the number of rows in the table.""" + if self._df is None: + if await self._storage.has(self._file_key): + data = await self._storage.get(self._file_key, as_bytes=True) + self._df = pd.read_parquet(BytesIO(data)) + else: + return 0 + return len(self._df) + + async def has(self, row_id: str) -> bool: + """Check if row with given ID exists.""" + async for row in self: + if isinstance(row, dict): + if row.get("id") == row_id: + return True + elif getattr(row, "id", None) == row_id: + return True + return False + + async def write(self, row: dict[str, Any]) -> None: + """Accumulate a single row for later batch write. + + Rows are stored in memory and written to Parquet format + when close() is called. + + Args + ---- + row: Dictionary representing a single row to write. + """ + self._write_rows.append(row) + + async def close(self) -> None: + """Flush accumulated rows to Parquet file and release resources. + + Converts all accumulated rows to a DataFrame and writes + to storage as a Parquet file. 
+ """ + if self._write_rows: + df = pd.DataFrame(self._write_rows) + await self._storage.set(self._file_key, df.to_parquet()) + self._write_rows = [] + + self._df = None diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py index add5599839..3f8d3b518d 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py @@ -10,6 +10,7 @@ import pandas as pd from graphrag_storage.storage import Storage +from graphrag_storage.tables.parquet_table import ParquetTable from graphrag_storage.tables.table import RowTransformer, Table from graphrag_storage.tables.table_provider import TableProvider @@ -111,16 +112,19 @@ def list(self) -> list[str]: def open(self, table_name: str, transformer: RowTransformer | None = None) -> Table: """Open a table for streaming row operations. - Not yet implemented for Parquet tables. Parquet format requires - loading data in chunks rather than true row-by-row streaming. + Returns a ParquetTable that simulates streaming by loading the + DataFrame and iterating rows, or accumulating writes for batch output. - Raises - ------ - NotImplementedError: - Streaming access is not yet supported for Parquet tables. + Args + ---- + table_name: str + The name of the table to open. + transformer: RowTransformer | None + Optional callable to transform each row on read. + + Returns + ------- + Table: + A ParquetTable instance for row-by-row access. """ - msg = ( - "Streaming access not yet implemented for ParquetTableProvider. " - "Use read_dataframe() for bulk access." 
- ) - raise NotImplementedError(msg) + return ParquetTable(self._storage, table_name, transformer) diff --git a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py index 1aa012bb78..1875054f75 100644 --- a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py @@ -115,6 +115,8 @@ async def create_base_text_units( doc_index, total_rows, ) + + def chunk_document( doc: dict[str, Any], chunker: Chunker, From bacbab45460448129a7da11128b73d9eaadb33f8 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Wed, 11 Feb 2026 16:40:16 +0000 Subject: [PATCH 05/10] add sample rows return --- .../index/workflows/create_base_text_units.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py index 1875054f75..a3d8964255 100644 --- a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py @@ -39,7 +39,7 @@ async def run_workflow( context.output_table_provider.open("text_units") as text_units_table, ): total_rows = await documents_table.length() - await create_base_text_units( + sample_rows = await create_base_text_units( documents_table, text_units_table, total_rows, @@ -50,7 +50,7 @@ async def run_workflow( ) logger.info("Workflow completed: create_base_text_units") - return WorkflowFunctionOutput(result=None) + return WorkflowFunctionOutput(result=sample_rows) async def create_base_text_units( @@ -61,7 +61,7 @@ async def create_base_text_units( tokenizer: Tokenizer, chunker: Chunker, prepend_metadata: list[str] | None = None, -) -> None: +) -> list[dict[str, Any]]: """Transform documents into chunked text units via streaming read/write. 
Reads documents row-by-row from an async iterable and writes text units @@ -93,6 +93,8 @@ async def create_base_text_units( ) doc_index = 0 + sample_rows: list[dict[str, Any]] = [] + sample_size = 5 async for doc in documents_table: chunks = chunk_document(doc, chunker, prepend_metadata) @@ -108,6 +110,9 @@ async def create_base_text_units( row["id"] = gen_sha512_hash(row, ["text"]) await text_units_table.write(row) + if len(sample_rows) < sample_size: + sample_rows.append(row) + doc_index += 1 tick() logger.info( @@ -116,6 +121,8 @@ async def create_base_text_units( total_rows, ) + return sample_rows + def chunk_document( doc: dict[str, Any], From 27b2031547c350d7ede885001d3bdcfbc1bd6a9c Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Thu, 12 Feb 2026 00:08:24 +0000 Subject: [PATCH 06/10] add load_input_documents and truncate mode to Table --- .../graphrag_storage/tables/csv_table.py | 19 ++++++-- .../tables/csv_table_provider.py | 9 +++- .../graphrag_storage/tables/parquet_table.py | 15 ++++-- .../tables/parquet_table_provider.py | 12 ++++- .../graphrag_storage/tables/table_provider.py | 8 +++- .../index/workflows/load_input_documents.py | 46 ++++++++++++------- 6 files changed, 79 insertions(+), 30 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py index 2630ac8ad7..d179688c21 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py @@ -46,6 +46,7 @@ def __init__( storage: Storage, table_name: str, transformer: RowTransformer | None = None, + truncate: bool = True, ): """Initialize with storage backend and table name. @@ -55,11 +56,14 @@ def __init__( transformer: Optional callable to transform each row before yielding. Receives a dict, returns a transformed dict. Defaults to identity (no transformation). + truncate: If True (default), truncate file on first write. 
+ If False, append to existing file. """ self._storage = storage self._table_name = table_name self._file_key = f"{table_name}.csv" self._transformer = transformer or _identity + self._truncate = truncate self._write_file: TextIOWrapper | None = None self._writer: csv.DictWriter | None = None self._header_written = False @@ -114,8 +118,9 @@ async def has(self, row_id: str) -> bool: async def write(self, row: dict[str, Any]) -> None: """Write a single row to the CSV file. - On first write, opens the file and writes the header row. - Subsequent writes append rows to the file. + On first write, opens the file. If truncate=True, overwrites any existing + file and writes header. If truncate=False, appends to existing file + (skips header if file exists). Args ---- @@ -124,10 +129,14 @@ async def write(self, row: dict[str, Any]) -> None: if isinstance(self._storage, FileStorage) and self._write_file is None: file_path = self._storage.get_path(self._file_key) file_path.parent.mkdir(parents=True, exist_ok=True) - self._write_file = Path.open(file_path, "w", encoding="utf-8", newline="") + file_exists = file_path.exists() and file_path.stat().st_size > 0 + mode = "w" if self._truncate else "a" + write_header = self._truncate or not file_exists + self._write_file = Path.open(file_path, mode, encoding="utf-8", newline="") self._writer = csv.DictWriter(self._write_file, fieldnames=list(row.keys())) - self._writer.writeheader() - self._header_written = True + if write_header: + self._writer.writeheader() + self._header_written = write_header if self._writer is not None: self._writer.writerow(row) diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py index a8d97df2b3..f2ef8adebc 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py @@ -116,7 +116,12 @@ def list(self) 
-> list[str]: ] def open( - self, table_name: str, transformer: RowTransformer | None = None + self, + table_name: str, + transformer: RowTransformer | None = None, + truncate: bool = True, ) -> CSVTable: """Open table for streaming.""" - return CSVTable(self._storage, table_name, transformer=transformer) + return CSVTable( + self._storage, table_name, transformer=transformer, truncate=truncate + ) diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py index b743db4296..2c62713987 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py @@ -52,6 +52,7 @@ def __init__( storage: Storage, table_name: str, transformer: RowTransformer | None = None, + truncate: bool = True, ): """Initialize with storage backend and table name. @@ -61,11 +62,14 @@ def __init__( transformer: Optional callable to transform each row before yielding. Receives a dict, returns a transformed dict. Defaults to identity (no transformation). + truncate: If True (default), overwrite file on close. + If False, append to existing file. """ self._storage = storage self._table_name = table_name self._file_key = f"{table_name}.parquet" self._transformer = transformer or _identity + self._truncate = truncate self._df: pd.DataFrame | None = None self._write_rows: list[dict[str, Any]] = [] @@ -131,11 +135,16 @@ async def close(self) -> None: """Flush accumulated rows to Parquet file and release resources. Converts all accumulated rows to a DataFrame and writes - to storage as a Parquet file. + to storage as a Parquet file. If truncate=False and file exists, + appends to existing data. 
""" if self._write_rows: - df = pd.DataFrame(self._write_rows) - await self._storage.set(self._file_key, df.to_parquet()) + new_df = pd.DataFrame(self._write_rows) + if not self._truncate and await self._storage.has(self._file_key): + existing_data = await self._storage.get(self._file_key, as_bytes=True) + existing_df = pd.read_parquet(BytesIO(existing_data)) + new_df = pd.concat([existing_df, new_df], ignore_index=True) + await self._storage.set(self._file_key, new_df.to_parquet()) self._write_rows = [] self._df = None diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py index 3f8d3b518d..b6c6f251bc 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py @@ -109,7 +109,12 @@ def list(self) -> list[str]: for file in self._storage.find(re.compile(r"\.parquet$")) ] - def open(self, table_name: str, transformer: RowTransformer | None = None) -> Table: + def open( + self, + table_name: str, + transformer: RowTransformer | None = None, + truncate: bool = True, + ) -> Table: """Open a table for streaming row operations. Returns a ParquetTable that simulates streaming by loading the @@ -121,10 +126,13 @@ def open(self, table_name: str, transformer: RowTransformer | None = None) -> Ta The name of the table to open. transformer: RowTransformer | None Optional callable to transform each row on read. + truncate: bool + If True (default), overwrite existing file on close. + If False, append new rows to existing file. Returns ------- Table: A ParquetTable instance for row-by-row access. 
""" - return ParquetTable(self._storage, table_name, transformer) + return ParquetTable(self._storage, table_name, transformer, truncate=truncate) diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py index 74bb049937..39965839f8 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py @@ -78,7 +78,10 @@ def list(self) -> list[str]: @abstractmethod def open( - self, table_name: str, transformer: RowTransformer | None = None + self, + table_name: str, + transformer: RowTransformer | None = None, + truncate: bool = True, ) -> Table: # Returns Table instance """Open a table for row-by-row streaming operations. @@ -88,6 +91,9 @@ def open( The name of the table to open. transformer: RowTransformer | None Optional transformer function to apply to each row. + truncate: bool + If True (default), truncate existing file on first write. + If False, append rows to existing file (DB-like behavior). 
Returns ------- diff --git a/packages/graphrag/graphrag/index/workflows/load_input_documents.py b/packages/graphrag/graphrag/index/workflows/load_input_documents.py index 26166bb279..a2d7ba937e 100644 --- a/packages/graphrag/graphrag/index/workflows/load_input_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_input_documents.py @@ -8,6 +8,7 @@ import pandas as pd from graphrag_input import InputReader, create_input_reader +from graphrag_storage.tables.table import Table from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.typing.context import PipelineRunContext @@ -23,26 +24,37 @@ async def run_workflow( """Load and parse input documents into a standard format.""" input_reader = create_input_reader(config.input, context.input_storage) - output = await load_input_documents(input_reader) + async with ( + context.output_table_provider.open("documents") as documents_table, + ): + sample, total_count = await load_input_documents(input_reader, documents_table) - if len(output) == 0: - msg = "Error reading documents, please see logs." - logger.error(msg) - raise ValueError(msg) + if total_count == 0: + msg = "Error reading documents, please see logs." 
+ logger.error(msg) + raise ValueError(msg) - logger.info("Final # of rows loaded: %s", len(output)) - context.stats.num_documents = len(output) + logger.info("Final # of rows loaded: %s", total_count) + context.stats.num_documents = total_count - await context.output_table_provider.write_dataframe("documents", output) + return WorkflowFunctionOutput(result=sample) - return WorkflowFunctionOutput(result=output) - -async def load_input_documents(input_reader: InputReader) -> pd.DataFrame: +async def load_input_documents( + input_reader: InputReader, documents_table: Table, sample_size: int = 5 +) -> tuple[pd.DataFrame, int]: """Load and parse input documents into a standard format.""" - documents = [asdict(doc) async for doc in input_reader] - documents = pd.DataFrame(documents) - documents["human_readable_id"] = documents.index - if "raw_data" not in documents.columns: - documents["raw_data"] = pd.Series(dtype="object") - return documents + sample: list[dict] = [] + idx = 0 + + async for doc in input_reader: + row = asdict(doc) + row["human_readable_id"] = idx + if "raw_data" not in row: + row["raw_data"] = None + await documents_table.write(row) + if len(sample) < sample_size: + sample.append(row) + idx += 1 + + return pd.DataFrame(sample), idx From 3e80e5b2f9f57bcd145a84182d78a11a75e2e08b Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Thu, 12 Feb 2026 00:25:14 +0000 Subject: [PATCH 07/10] add semver --- .semversioner/next-release/patch-20260212002508389038.json | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .semversioner/next-release/patch-20260212002508389038.json diff --git a/.semversioner/next-release/patch-20260212002508389038.json b/.semversioner/next-release/patch-20260212002508389038.json new file mode 100644 index 0000000000..64c07d115f --- /dev/null +++ b/.semversioner/next-release/patch-20260212002508389038.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "add streaming to the first two workflows" +} From
877f5ab298a7317891618090aeae0247af307a55 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Thu, 12 Feb 2026 20:32:23 +0000 Subject: [PATCH 08/10] add csv output reader fix --- packages/graphrag-input/graphrag_input/csv.py | 5 +++++ .../graphrag_storage/tables/csv_table.py | 8 ++++++-- .../graphrag_storage/tables/csv_table_provider.py | 13 +++++++++++-- .../index/workflows/generate_text_embeddings.py | 6 +++--- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/packages/graphrag-input/graphrag_input/csv.py b/packages/graphrag-input/graphrag_input/csv.py index 6c0f51dd3a..370aacca00 100644 --- a/packages/graphrag-input/graphrag_input/csv.py +++ b/packages/graphrag-input/graphrag_input/csv.py @@ -5,12 +5,17 @@ import csv import logging +import sys from graphrag_input.structured_file_reader import StructuredFileReader from graphrag_input.text_document import TextDocument logger = logging.getLogger(__name__) +try: + csv.field_size_limit(sys.maxsize) +except OverflowError: + csv.field_size_limit(100 * 1024 * 1024) class CSVFileReader(StructuredFileReader): """Reader implementation for csv files.""" diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py index d179688c21..1a63a5f817 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py @@ -47,6 +47,7 @@ def __init__( table_name: str, transformer: RowTransformer | None = None, truncate: bool = True, + encoding: str = "utf-8", ): """Initialize with storage backend and table name. @@ -58,12 +59,15 @@ def __init__( Defaults to identity (no transformation). truncate: If True (default), truncate file on first write. If False, append to existing file. + encoding: Character encoding for reading/writing CSV files. + Defaults to "utf-8".
""" self._storage = storage self._table_name = table_name self._file_key = f"{table_name}.csv" self._transformer = transformer or _identity self._truncate = truncate + self._encoding = encoding self._write_file: TextIOWrapper | None = None self._writer: csv.DictWriter | None = None self._header_written = False @@ -85,7 +89,7 @@ async def _aiter_impl(self) -> AsyncIterator[Any]: """Implement async iteration over rows.""" if isinstance(self._storage, FileStorage): file_path = self._storage.get_path(self._file_key) - with Path.open(file_path, "r", encoding="utf-8") as f: + with Path.open(file_path, "r", encoding=self._encoding) as f: reader = csv.DictReader(f) for row in reader: yield _apply_transformer(self._transformer, row) @@ -132,7 +136,7 @@ async def write(self, row: dict[str, Any]) -> None: file_exists = file_path.exists() and file_path.stat().st_size > 0 mode = "w" if self._truncate else "a" write_header = self._truncate or not file_exists - self._write_file = Path.open(file_path, mode, encoding="utf-8", newline="") + self._write_file = Path.open(file_path, mode, encoding=self._encoding, newline="") self._writer = csv.DictWriter(self._write_file, fieldnames=list(row.keys())) if write_header: self._writer.writeheader() diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py index f2ef8adebc..6d33dea777 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py @@ -120,8 +120,17 @@ def open( table_name: str, transformer: RowTransformer | None = None, truncate: bool = True, + encoding: str = "utf-8", ) -> CSVTable: - """Open table for streaming.""" + """Open table for streaming. 
+ + Args: + table_name: Name of the table to open + transformer: Optional callable to transform each row + truncate: If True, truncate file on first write + encoding: Character encoding for reading/writing CSV files. + Defaults to "utf-8". + """ return CSVTable( - self._storage, table_name, transformer=transformer, truncate=truncate + self._storage, table_name, transformer=transformer, truncate=truncate, encoding=encoding ) diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index ebce58b914..69e30e9b65 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -105,13 +105,13 @@ async def generate_text_embeddings( """All the steps to generate all embeddings.""" embedding_param_map = { text_unit_text_embedding: { - "data": text_units.loc[:, ["id", "text"]] + "data": text_units.loc[:, ["id", "text"]].fillna("") if text_units is not None else None, "embed_column": "text", }, entity_description_embedding: { - "data": entities.loc[:, ["id", "title", "description"]].assign( + "data": entities.loc[:, ["id", "title", "description"]].fillna("").assign( title_description=lambda df: df["title"] + ":" + df["description"] ) if entities is not None @@ -119,7 +119,7 @@ async def generate_text_embeddings( "embed_column": "title_description", }, community_full_content_embedding: { - "data": community_reports.loc[:, ["id", "full_content"]] + "data": community_reports.loc[:, ["id", "full_content"]].fillna("") if community_reports is not None else None, "embed_column": "full_content", From 981a51b90d489052970522bf183910f812c6f2a9 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Thu, 12 Feb 2026 20:34:32 +0000 Subject: [PATCH 09/10] run formatter --- packages/graphrag-input/graphrag_input/csv.py | 1 + .../graphrag-storage/graphrag_storage/tables/csv_table.py | 4 +++- 
.../graphrag_storage/tables/csv_table_provider.py | 8 ++++++-- .../graphrag/index/workflows/generate_text_embeddings.py | 7 ++++--- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/packages/graphrag-input/graphrag_input/csv.py b/packages/graphrag-input/graphrag_input/csv.py index 370aacca00..e041bff275 100644 --- a/packages/graphrag-input/graphrag_input/csv.py +++ b/packages/graphrag-input/graphrag_input/csv.py @@ -17,6 +17,7 @@ except OverflowError: csv.field_size_limit(100 * 1024 * 1024) + class CSVFileReader(StructuredFileReader): """Reader implementation for csv files.""" diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py index 1a63a5f817..04c6fb8756 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py @@ -136,7 +136,9 @@ async def write(self, row: dict[str, Any]) -> None: file_exists = file_path.exists() and file_path.stat().st_size > 0 mode = "w" if self._truncate else "a" write_header = self._truncate or not file_exists - self._write_file = Path.open(file_path, mode, encoding=self._encoding, newline="") + self._write_file = Path.open( + file_path, mode, encoding=self._encoding, newline="" + ) self._writer = csv.DictWriter(self._write_file, fieldnames=list(row.keys())) if write_header: self._writer.writeheader() diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py index 6d33dea777..2561bde0d8 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py @@ -123,7 +123,7 @@ def open( encoding: str = "utf-8", ) -> CSVTable: """Open table for streaming. 
- + Args: table_name: Name of the table to open transformer: Optional callable to transform each row @@ -132,5 +132,9 @@ def open( Defaults to "utf-8". """ return CSVTable( - self._storage, table_name, transformer=transformer, truncate=truncate, encoding=encoding + self._storage, + table_name, + transformer=transformer, + truncate=truncate, + encoding=encoding, ) diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index 69e30e9b65..a78de2bb58 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -111,9 +111,10 @@ async def generate_text_embeddings( "embed_column": "text", }, entity_description_embedding: { - "data": entities.loc[:, ["id", "title", "description"]].fillna("").assign( - title_description=lambda df: df["title"] + ":" + df["description"] - ) + "data": entities + .loc[:, ["id", "title", "description"]] + .fillna("") + .assign(title_description=lambda df: df["title"] + ":" + df["description"]) if entities is not None else None, "embed_column": "title_description", From 657b789b397bb9af546024750697611898d3474a Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Thu, 12 Feb 2026 20:45:18 +0000 Subject: [PATCH 10/10] add csv config --- .../graphrag-storage/graphrag_storage/tables/csv_table.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py index 04c6fb8756..0c55b17ea3 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py +++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py @@ -7,6 +7,7 @@ import csv import inspect +import sys from pathlib import Path from typing import TYPE_CHECKING, Any @@ -21,6 +22,11 @@ from graphrag_storage import Storage +try: + 
csv.field_size_limit(sys.maxsize) +except OverflowError: + csv.field_size_limit(100 * 1024 * 1024) + def _identity(row: dict[str, Any]) -> Any: """Return row unchanged (default transformer)."""