Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260206214141420353.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add CSVTableProvider"
}
19 changes: 19 additions & 0 deletions packages/graphrag-storage/graphrag_storage/memory_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

"""In-memory storage implementation."""

import re
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any

from graphrag_storage.file_storage import FileStorage
Expand Down Expand Up @@ -81,3 +83,20 @@ def child(self, name: str | None) -> "Storage":
def keys(self) -> list[str]:
    """Return a list of every key currently held in the storage."""
    return [*self._storage]

def find(self, file_pattern: re.Pattern[str]) -> Iterator[str]:
    """Yield the keys in memory storage whose names match *file_pattern*.

    Args
    ----
    file_pattern: re.Pattern[str]
        Regular expression pattern to match against keys.

    Yields
    ------
    str:
        Keys that match the pattern.
    """
    yield from (key for key in self._storage if file_pattern.search(key))
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""CSV-based table provider implementation."""

import logging
import re
from io import StringIO

import pandas as pd

from graphrag_storage.storage import Storage
from graphrag_storage.tables.table_provider import TableProvider

logger = logging.getLogger(__name__)


class CSVTableProvider(TableProvider):
    """Table provider that stores tables as CSV files using an underlying Storage instance.

    This provider converts between pandas DataFrames and csv format,
    storing the data through a Storage backend (file, blob, cosmos, etc.).
    """

    def __init__(self, storage: Storage, **kwargs) -> None:
        """Initialize the CSV table provider with an underlying storage instance.

        Args
        ----
        storage: Storage
            The storage instance to use for reading and writing csv files.
        **kwargs: Any
            Additional keyword arguments (currently unused).
        """
        self._storage = storage

    async def read_dataframe(self, table_name: str) -> pd.DataFrame:
        """Read a table from storage as a pandas DataFrame.

        Args
        ----
        table_name: str
            The name of the table to read. The file will be accessed as '{table_name}.csv'.

        Returns
        -------
        pd.DataFrame:
            The table data loaded from the csv file.

        Raises
        ------
        ValueError:
            If the table file does not exist in storage.
        Exception:
            If there is an error reading or parsing the csv file.
        """
        filename = f"{table_name}.csv"
        if not await self._storage.has(filename):
            # Name the missing file so the error is actionable (the previous
            # message was a placeholder-less f-string emitting "(unknown)").
            msg = f"Could not find {filename} in storage!"
            raise ValueError(msg)
        try:
            logger.info("reading table from storage: %s", filename)
            csv_data = await self._storage.get(filename, as_bytes=False)
            # Handle empty CSV (pandas can't parse files with no columns)
            if not csv_data or csv_data.strip() == "":
                return pd.DataFrame()
            return pd.read_csv(StringIO(csv_data))
        except Exception:
            logger.exception("error loading table from storage: %s", filename)
            raise

    async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None:
        """Write a pandas DataFrame to storage as a CSV file.

        Args
        ----
        table_name: str
            The name of the table to write. The file will be saved as '{table_name}.csv'.
        df: pd.DataFrame
            The DataFrame to write to storage.
        """
        await self._storage.set(f"{table_name}.csv", df.to_csv(index=False))

    async def has(self, table_name: str) -> bool:
        """Check if a table exists in storage.

        Args
        ----
        table_name: str
            The name of the table to check.

        Returns
        -------
        bool:
            True if the table exists, False otherwise.
        """
        return await self._storage.has(f"{table_name}.csv")

    def list(self) -> list[str]:
        """List all table names in storage.

        Returns
        -------
        list[str]:
            List of table names (without .csv extension).
        """
        # Strip only the trailing extension; str.replace would also remove
        # any '.csv' occurring in the middle of a table name.
        return [
            file.removesuffix(".csv")
            for file in self._storage.find(re.compile(r"\.csv$"))
        ]
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None:
"""
await self._storage.set(f"{table_name}.parquet", df.to_parquet())

async def has_dataframe(self, table_name: str) -> bool:
async def has(self, table_name: str) -> bool:
"""Check if a table exists in storage.

Args
Expand All @@ -94,8 +94,8 @@ async def has_dataframe(self, table_name: str) -> bool:
"""
return await self._storage.has(f"{table_name}.parquet")

def find_tables(self) -> list[str]:
"""Find all table names in storage.
def list(self) -> list[str]:
"""List all table names in storage.

Returns
-------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None:
"""

@abstractmethod
async def has_dataframe(self, table_name: str) -> bool:
async def has(self, table_name: str) -> bool:
"""Check if a table exists in the provider.

Args
Expand All @@ -65,8 +65,8 @@ async def has_dataframe(self, table_name: str) -> bool:
"""

@abstractmethod
def find_tables(self) -> list[str]:
"""Find all table names in the provider.
def list(self) -> list[str]:
"""List all table names in the provider.

Returns
-------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def create_table_provider(
)

register_table_provider(TableType.Parquet, ParquetTableProvider)
case TableType.CSV:
from graphrag_storage.tables.csv_table_provider import (
CSVTableProvider,
)

register_table_provider(TableType.CSV, CSVTableProvider)
case _:
msg = f"TableProviderConfig.type '{table_type}' is not registered in the TableProviderFactory. Registered types: {', '.join(table_provider_factory.keys())}."
raise ValueError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class TableType(StrEnum):
"""Enum for table storage types."""

Parquet = "parquet"
CSV = "csv"
2 changes: 1 addition & 1 deletion packages/graphrag/graphrag/cli/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _resolve_output_files(
# for optional output files, set the dict entry to None instead of erroring out if it does not exist
if optional_list:
for optional_file in optional_list:
file_exists = asyncio.run(table_provider.has_dataframe(optional_file))
file_exists = asyncio.run(table_provider.has(optional_file))
if file_exists:
df_value = asyncio.run(table_provider.read_dataframe(optional_file))
dataframe_dict[optional_file] = df_value
Expand Down
2 changes: 1 addition & 1 deletion packages/graphrag/graphrag/index/run/run_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,6 @@ async def _copy_previous_output(
previous_table_provider: TableProvider,
) -> None:
"""Copy all parquet tables from output to previous storage for backup."""
for table_name in output_table_provider.find_tables():
for table_name in output_table_provider.list():
table = await output_table_provider.read_dataframe(table_name)
await previous_table_provider.write_dataframe(table_name, table)
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ async def run_workflow(
communities = await context.output_table_provider.read_dataframe("communities")

claims = None
if (
config.extract_claims.enabled
and await context.output_table_provider.has_dataframe("covariates")
if config.extract_claims.enabled and await context.output_table_provider.has(
"covariates"
):
claims = await context.output_table_provider.read_dataframe("covariates")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ async def run_workflow(
)

final_covariates = None
if (
config.extract_claims.enabled
and await context.output_table_provider.has_dataframe("covariates")
if config.extract_claims.enabled and await context.output_table_provider.has(
"covariates"
):
final_covariates = await context.output_table_provider.read_dataframe(
"covariates"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ async def run_workflow(
get_update_table_providers(config, context.state["update_timestamp"])
)

if await previous_table_provider.has_dataframe(
if await previous_table_provider.has(
"covariates"
) and await delta_table_provider.has_dataframe("covariates"):
) and await delta_table_provider.has("covariates"):
logger.info("Updating Covariates")
await _update_covariates(
previous_table_provider, delta_table_provider, output_table_provider
Expand Down
119 changes: 119 additions & 0 deletions tests/unit/storage/test_csv_table_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import unittest
from io import StringIO

import pandas as pd
import pytest
from graphrag_storage import (
StorageConfig,
StorageType,
create_storage,
)
from graphrag_storage.tables.csv_table_provider import CSVTableProvider


class TestCSVTableProvider(unittest.IsolatedAsyncioTestCase):
    """Test suite for CSVTableProvider."""

    def setUp(self):
        """Build an in-memory storage backend and wrap it in a provider."""
        config = StorageConfig(type=StorageType.Memory)
        self.storage = create_storage(config)
        self.table_provider = CSVTableProvider(storage=self.storage)

    async def asyncTearDown(self):
        """Remove everything written during the test."""
        await self.storage.clear()

    async def test_write_and_read(self):
        """A DataFrame written to the provider reads back unchanged."""
        expected = pd.DataFrame({
            "id": [1, 2, 3],
            "name": ["Alice", "Bob", "Charlie"],
            "age": [30, 25, 35],
        })

        await self.table_provider.write_dataframe("users", expected)
        loaded = await self.table_provider.read_dataframe("users")

        pd.testing.assert_frame_equal(loaded, expected)

    async def test_read_nonexistent_table_raises_error(self):
        """Reading a table that was never written raises ValueError."""
        with self.assertRaisesRegex(
            ValueError, r"Could not find nonexistent\.csv in storage!"
        ):
            await self.table_provider.read_dataframe("nonexistent")

    async def test_empty_dataframe(self):
        """An empty DataFrame round-trips without error."""
        empty = pd.DataFrame()

        await self.table_provider.write_dataframe("empty", empty)
        loaded = await self.table_provider.read_dataframe("empty")

        pd.testing.assert_frame_equal(loaded, empty)

    async def test_dataframe_with_multiple_types(self):
        """Int, float, string, and bool columns all survive the round-trip."""
        mixed = pd.DataFrame({
            "int_col": [1, 2, 3],
            "float_col": [1.1, 2.2, 3.3],
            "str_col": ["a", "b", "c"],
            "bool_col": [True, False, True],
        })

        await self.table_provider.write_dataframe("mixed", mixed)
        loaded = await self.table_provider.read_dataframe("mixed")

        pd.testing.assert_frame_equal(loaded, mixed)

    async def test_storage_persistence(self):
        """The provider writes real CSV text into the backing storage."""
        frame = pd.DataFrame({"x": [1, 2, 3]})

        await self.table_provider.write_dataframe("test", frame)

        assert await self.storage.has("test.csv")

        raw_csv = await self.storage.get("test.csv", as_bytes=False)
        reparsed = pd.read_csv(StringIO(raw_csv))

        pd.testing.assert_frame_equal(reparsed, frame)

    async def test_has(self):
        """has() reports table existence before and after a write."""
        frame = pd.DataFrame({"a": [1, 2, 3]})

        # Table doesn't exist yet
        assert not await self.table_provider.has("test_table")

        await self.table_provider.write_dataframe("test_table", frame)

        # Now it exists
        assert await self.table_provider.has("test_table")

    async def test_list(self):
        """list() enumerates every stored table name without extensions."""
        # Initially empty
        assert self.table_provider.list() == []

        # Create some tables
        fixtures = {
            "table1": {"a": [1, 2, 3]},
            "table2": {"b": [4, 5, 6]},
            "table3": {"c": [7, 8, 9]},
        }
        for name, data in fixtures.items():
            await self.table_provider.write_dataframe(name, pd.DataFrame(data))

        # List tables
        tables = self.table_provider.list()
        assert len(tables) == 3
        assert set(tables) == {"table1", "table2", "table3"}
6 changes: 3 additions & 3 deletions tests/unit/storage/test_parquet_table_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ async def test_storage_persistence(self):

pd.testing.assert_frame_equal(loaded_df, df)

async def test_has(self):
    frame = pd.DataFrame({"a": [1, 2, 3]})

    # Table doesn't exist yet
    assert not await self.table_provider.has("test_table")

    # Write the table
    await self.table_provider.write_dataframe("test_table", frame)

    # Now it exists
    assert await self.table_provider.has("test_table")