Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260211214912747264.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add async iterator support to InputReader and use it in load_input_documents and load_update_documents workflows."
}
34 changes: 23 additions & 11 deletions packages/graphrag-input/graphrag_input/input_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import AsyncIterator

from graphrag_storage import Storage

from graphrag_input.text_document import TextDocument
Expand All @@ -33,34 +35,44 @@ def __init__(
self._file_pattern = file_pattern

async def read_files(self) -> list[TextDocument]:
    """Collect every document from storage and return them as one list.

    Convenience wrapper over the async iterator protocol: equivalent to
    exhausting `async for doc in self` and gathering the results.
    """
    documents: list[TextDocument] = []
    async for document in self:
        documents.append(document)
    return documents

def __aiter__(self) -> AsyncIterator[TextDocument]:
    """Make the reader itself usable in `async for doc in reader` loops."""
    iterator = self._iterate_files()
    return iterator

async def _iterate_files(self) -> AsyncIterator[TextDocument]:
    """Async generator that yields documents one at a time as files are loaded.

    Files matching ``self._file_pattern`` are located in storage and parsed
    with :meth:`read_file`. A failing file is logged and skipped so that one
    bad input does not abort the whole run.

    Yields:
        TextDocument: each parsed document, in storage-listing order.
    """
    files = list(self._storage.find(re.compile(self._file_pattern)))
    if not files:  # idiomatic emptiness check instead of len(...) == 0
        # Lazy %-style args: the message is only formatted if the record
        # is actually emitted.
        logger.warning("No %s matches found in storage", self._file_pattern)
        return

    file_count = len(files)
    doc_count = 0

    for file in files:
        try:
            for doc in await self.read_file(file):
                doc_count += 1
                yield doc
        except Exception as e:  # noqa: BLE001 (catching Exception is fine here)
            # Best-effort: skip unreadable files rather than failing the run.
            logger.warning("Warning! Error loading file %s. Skipping...", file)
            logger.warning("Error: %s", e)

    logger.info(
        "Found %d %s files, loading %d",
        file_count,
        self._file_pattern,
        doc_count,
    )
    logger.info(
        "Total number of unfiltered %s rows: %d",
        self._file_pattern,
        doc_count,
    )

@abstractmethod
async def read_file(self, path: str) -> list[TextDocument]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""A module containing run_workflow method definition."""

import logging
from dataclasses import asdict

import pandas as pd
from graphrag_input import InputReader, create_input_reader
Expand Down Expand Up @@ -39,8 +40,9 @@ async def run_workflow(

async def load_input_documents(input_reader: InputReader) -> pd.DataFrame:
    """Load and parse input documents into a standard format.

    Streams documents from the reader, converts each dataclass to a dict,
    and assembles a DataFrame with a `human_readable_id` column and a
    guaranteed `raw_data` column.
    """
    rows: list[dict] = []
    async for document in input_reader:
        rows.append(asdict(document))
    frame = pd.DataFrame(rows)
    frame["human_readable_id"] = frame.index
    if "raw_data" not in frame.columns:
        frame["raw_data"] = pd.Series(dtype="object")
    return frame
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""A module containing run_workflow method definition."""

import logging
from dataclasses import asdict

import pandas as pd
from graphrag_input.input_reader import InputReader
Expand Down Expand Up @@ -50,7 +51,8 @@ async def load_update_documents(
previous_table_provider: TableProvider,
) -> pd.DataFrame:
"""Load and parse update-only input documents into a standard format."""
input_documents = pd.DataFrame(await input_reader.read_files())
input_documents = [asdict(doc) async for doc in input_reader]
input_documents = pd.DataFrame(input_documents)
input_documents["human_readable_id"] = input_documents.index
if "raw_data" not in input_documents.columns:
input_documents["raw_data"] = pd.Series(dtype="object")
Expand Down