diff --git a/.semversioner/next-release/patch-20260211214912747264.json b/.semversioner/next-release/patch-20260211214912747264.json new file mode 100644 index 0000000000..e65444da75 --- /dev/null +++ b/.semversioner/next-release/patch-20260211214912747264.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Add async iterator support to InputReader and use it in load_input_documents and load_update_documents workflows." +} diff --git a/packages/graphrag-input/graphrag_input/input_reader.py b/packages/graphrag-input/graphrag_input/input_reader.py index be95168336..ae840eb8f2 100644 --- a/packages/graphrag-input/graphrag_input/input_reader.py +++ b/packages/graphrag-input/graphrag_input/input_reader.py @@ -11,6 +11,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import AsyncIterator + from graphrag_storage import Storage from graphrag_input.text_document import TextDocument @@ -33,34 +35,44 @@ def __init__( self._file_pattern = file_pattern async def read_files(self) -> list[TextDocument]: - """Load files from storage and apply a loader function based on file type. Process metadata on the results if needed.""" + """Load all files from storage and return them as a single list.""" + return [doc async for doc in self] + + def __aiter__(self) -> AsyncIterator[TextDocument]: + """Return the async iterator, enabling `async for doc in reader`.""" + return self._iterate_files() + + async def _iterate_files(self) -> AsyncIterator[TextDocument]: + """Async generator that yields documents one at a time as files are loaded.""" files = list(self._storage.find(re.compile(self._file_pattern))) if len(files) == 0: msg = f"No {self._file_pattern} matches found in storage" logger.warning(msg) - files = [] + return - documents: list[TextDocument] = [] + file_count = len(files) + doc_count = 0 for file in files: try: - documents.extend(await self.read_file(file)) + for doc in await self.read_file(file): + doc_count += 1 + yield doc except Exception as e: # noqa: BLE001 (catching Exception is fine here) logger.warning("Warning! Error loading file %s. Skipping...", file) logger.warning("Error: %s", e) logger.info( "Found %d %s files, loading %d", - len(files), + file_count, self._file_pattern, - len(documents), + doc_count, ) - total_files_log = ( - f"Total number of unfiltered {self._file_pattern} rows: {len(documents)}" + logger.info( + "Total number of unfiltered %s rows: %d", + self._file_pattern, + doc_count, ) - logger.info(total_files_log) - - return documents @abstractmethod async def read_file(self, path: str) -> list[TextDocument]: diff --git a/packages/graphrag/graphrag/index/workflows/load_input_documents.py b/packages/graphrag/graphrag/index/workflows/load_input_documents.py index 8e27ed0a2a..26166bb279 100644 --- a/packages/graphrag/graphrag/index/workflows/load_input_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_input_documents.py @@ -4,6 +4,7 @@ """A module containing run_workflow method definition.""" import logging +from dataclasses import asdict import pandas as pd from graphrag_input import InputReader, create_input_reader @@ -39,8 +40,9 @@ async def run_workflow( async def load_input_documents(input_reader: InputReader) -> pd.DataFrame: """Load and parse input documents into a standard format.""" - output = pd.DataFrame(await input_reader.read_files()) - output["human_readable_id"] = output.index - if "raw_data" not in output.columns: - output["raw_data"] = pd.Series(dtype="object") - return output + documents = [asdict(doc) async for doc in input_reader] + documents = pd.DataFrame(documents) + documents["human_readable_id"] = documents.index + if "raw_data" not in documents.columns: + documents["raw_data"] = pd.Series(dtype="object") + return documents diff --git a/packages/graphrag/graphrag/index/workflows/load_update_documents.py b/packages/graphrag/graphrag/index/workflows/load_update_documents.py index a61a228493..38104fdbf5 100644 --- a/packages/graphrag/graphrag/index/workflows/load_update_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_update_documents.py @@ -4,6 +4,7 @@ """A module containing run_workflow method definition.""" import logging +from dataclasses import asdict import pandas as pd from graphrag_input.input_reader import InputReader @@ -50,7 +51,8 @@ async def load_update_documents( previous_table_provider: TableProvider, ) -> pd.DataFrame: """Load and parse update-only input documents into a standard format.""" - input_documents = pd.DataFrame(await input_reader.read_files()) + input_documents = [asdict(doc) async for doc in input_reader] + input_documents = pd.DataFrame(input_documents) input_documents["human_readable_id"] = input_documents.index if "raw_data" not in input_documents.columns: input_documents["raw_data"] = pd.Series(dtype="object")