diff --git a/.semversioner/next-release/patch-20260206205026841660.json b/.semversioner/next-release/patch-20260206205026841660.json
new file mode 100644
index 0000000000..db0635dd4c
--- /dev/null
+++ b/.semversioner/next-release/patch-20260206205026841660.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add table provider factory."
+}
diff --git a/packages/graphrag-cache/graphrag_cache/cache_factory.py b/packages/graphrag-cache/graphrag_cache/cache_factory.py
index 6b1310754c..33a51099f0 100644
--- a/packages/graphrag-cache/graphrag_cache/cache_factory.py
+++ b/packages/graphrag-cache/graphrag_cache/cache_factory.py
@@ -5,20 +5,14 @@
 """Cache factory implementation."""
 
 from collections.abc import Callable
-from typing import TYPE_CHECKING
 
-from graphrag_common.factory import Factory
-from graphrag_storage import create_storage
+from graphrag_common.factory import Factory, ServiceScope
+from graphrag_storage import Storage, create_storage
 
+from graphrag_cache.cache import Cache
 from graphrag_cache.cache_config import CacheConfig
 from graphrag_cache.cache_type import CacheType
 
-if TYPE_CHECKING:
-    from graphrag_common.factory import ServiceScope
-    from graphrag_storage import Storage
-
-    from graphrag_cache.cache import Cache
-
 
 class CacheFactory(Factory["Cache"]):
     """A factory class for cache implementations."""
@@ -29,8 +23,8 @@ class CacheFactory(Factory["Cache"]):
 
 def register_cache(
     cache_type: str,
-    cache_initializer: Callable[..., "Cache"],
-    scope: "ServiceScope" = "transient",
+    cache_initializer: Callable[..., Cache],
+    scope: ServiceScope = "transient",
 ) -> None:
     """Register a custom cache implementation.
 
@@ -45,7 +39,7 @@
 
 
 def create_cache(
-    config: CacheConfig | None = None, storage: "Storage | None" = None
+    config: CacheConfig | None = None, storage: Storage | None = None
 ) -> "Cache":
     """Create a cache implementation based on the given configuration.
 
diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider_config.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider_config.py
new file mode 100644
index 0000000000..1255646c28
--- /dev/null
+++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider_config.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Table provider configuration model."""
+
+from pydantic import BaseModel, ConfigDict, Field
+
+from graphrag_storage.tables.table_type import TableType
+
+
+class TableProviderConfig(BaseModel):
+    """The default configuration section for table providers."""
+
+    model_config = ConfigDict(extra="allow")
+    """Allow extra fields to support custom table provider implementations."""
+
+    type: str = Field(
+        description="The table type to use.",
+        default=TableType.Parquet,
+    )
diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py
new file mode 100644
index 0000000000..93add5d8d1
--- /dev/null
+++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider_factory.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+
+"""Table provider factory implementation."""
+
+from collections.abc import Callable
+
+from graphrag_common.factory import Factory, ServiceScope
+
+from graphrag_storage.storage import Storage
+from graphrag_storage.tables.table_provider import TableProvider
+from graphrag_storage.tables.table_provider_config import TableProviderConfig
+from graphrag_storage.tables.table_type import TableType
+
+
+class TableProviderFactory(Factory[TableProvider]):
+    """A factory class for table storage implementations."""
+
+
+table_provider_factory = TableProviderFactory()
+
+
+def register_table_provider(
+    table_type: str,
+    table_initializer: Callable[..., TableProvider],
+    scope: ServiceScope = "transient",
+) -> None:
+    """Register a custom table provider implementation.
+
+    Args
+    ----
+    - table_type: str
+        The table type id to register.
+    - table_initializer: Callable[..., TableProvider]
+        The table initializer to register.
+    """
+    table_provider_factory.register(table_type, table_initializer, scope)
+
+
+def create_table_provider(
+    config: TableProviderConfig, storage: Storage | None = None
+) -> TableProvider:
+    """Create a table provider implementation based on the given configuration.
+
+    Args
+    ----
+    - config: TableProviderConfig
+        The table provider configuration to use.
+    - storage: Storage | None
+        The storage implementation to use for file-based TableProviders such as Parquet and CSV.
+
+    Returns
+    -------
+    TableProvider
+        The created table provider implementation.
+    """
+    config_model = config.model_dump()
+    table_type = config.type
+
+    if table_type not in table_provider_factory:
+        match table_type:
+            case TableType.Parquet:
+                from graphrag_storage.tables.parquet_table_provider import (
+                    ParquetTableProvider,
+                )
+
+                register_table_provider(TableType.Parquet, ParquetTableProvider)
+            case _:
+                msg = f"TableProviderConfig.type '{table_type}' is not registered in the TableProviderFactory. Registered types: {', '.join(table_provider_factory.keys())}."
+                raise ValueError(msg)
+
+    if storage:
+        config_model["storage"] = storage
+
+    return table_provider_factory.create(table_type, config_model)
diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_type.py b/packages/graphrag-storage/graphrag_storage/tables/table_type.py
new file mode 100644
index 0000000000..ab8cdf7015
--- /dev/null
+++ b/packages/graphrag-storage/graphrag_storage/tables/table_type.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+
+"""Builtin table storage implementation types."""
+
+from enum import StrEnum
+
+
+class TableType(StrEnum):
+    """Enum for table storage types."""
+
+    Parquet = "parquet"
diff --git a/packages/graphrag/graphrag/cli/query.py b/packages/graphrag/graphrag/cli/query.py
index 1f808420d4..d3f7109733 100644
--- a/packages/graphrag/graphrag/cli/query.py
+++ b/packages/graphrag/graphrag/cli/query.py
@@ -9,7 +9,7 @@
 from typing import TYPE_CHECKING, Any
 
 from graphrag_storage import create_storage
-from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider
+from graphrag_storage.tables.table_provider_factory import create_table_provider
 
 import graphrag.api as api
 from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
@@ -378,7 +378,7 @@ def _resolve_output_files(
     """Read indexing output files to a dataframe dict."""
     dataframe_dict = {}
     storage_obj = create_storage(config.output_storage)
-    table_provider = ParquetTableProvider(storage_obj)
+    table_provider = create_table_provider(config.table_provider, storage=storage_obj)
     for name in output_list:
         df_value = asyncio.run(table_provider.read_dataframe(name))
         dataframe_dict[name] = df_value
diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py
index 84fb2de884..dc28da97ca 100644
--- a/packages/graphrag/graphrag/config/models/graph_rag_config.py
+++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py
@@ -12,6 +12,7 @@
 from graphrag_input import InputConfig
 from graphrag_llm.config import ModelConfig
 from graphrag_storage import StorageConfig, StorageType
+from graphrag_storage.tables.table_provider_config import TableProviderConfig
 from graphrag_vectors import IndexSchema, VectorStoreConfig, VectorStoreType
 from pydantic import BaseModel, Field, model_validator
 
@@ -138,6 +139,11 @@ def _validate_update_output_storage_base_dir(self) -> None:
             Path(self.update_output_storage.base_dir).resolve()
         )
 
+    table_provider: TableProviderConfig = Field(
+        description="The table provider configuration.", default=TableProviderConfig()
+    )
+    """The table provider configuration. By default, parquet tables are read/written to disk. You can register custom table provider implementations."""
+
     cache: CacheConfig = Field(
         description="The cache configuration.",
         default=CacheConfig(**asdict(graphrag_config_defaults.cache)),
diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py
index 24ff39cc07..a76b161d36 100644
--- a/packages/graphrag/graphrag/index/run/run_pipeline.py
+++ b/packages/graphrag/graphrag/index/run/run_pipeline.py
@@ -13,8 +13,8 @@
 import pandas as pd
 from graphrag_cache import create_cache
 from graphrag_storage import create_storage
-from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider
 from graphrag_storage.tables.table_provider import TableProvider
+from graphrag_storage.tables.table_provider_factory import create_table_provider
 
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -36,9 +36,10 @@ async def run_pipeline(
 ) -> AsyncIterable[PipelineRunResult]:
     """Run all workflows using a simplified pipeline."""
     input_storage = create_storage(config.input_storage)
-    input_table_provider = ParquetTableProvider(input_storage)
     output_storage = create_storage(config.output_storage)
 
+    output_table_provider = create_table_provider(config.table_provider, output_storage)
+
     cache = create_cache(config.cache)
 
     # load existing state in case any workflows are stateful
@@ -56,13 +57,16 @@
         update_timestamp = time.strftime("%Y%m%d-%H%M%S")
         timestamped_storage = update_storage.child(update_timestamp)
         delta_storage = timestamped_storage.child("delta")
-        delta_table_provider = ParquetTableProvider(delta_storage)
+        delta_table_provider = create_table_provider(
+            config.table_provider, delta_storage
+        )
 
         # copy the previous output to a backup folder, so we can replace it with the update
         # we'll read from this later when we merge the old and new indexes
         previous_storage = timestamped_storage.child("previous")
-        previous_table_provider = ParquetTableProvider(previous_storage)
+        previous_table_provider = create_table_provider(
+            config.table_provider, previous_storage
+        )
 
-        output_table_provider = ParquetTableProvider(output_storage)
         await _copy_previous_output(output_table_provider, previous_table_provider)
         state["update_timestamp"] = update_timestamp
@@ -74,7 +78,6 @@
 
         context = create_run_context(
             input_storage=input_storage,
-            input_table_provider=input_table_provider,
             output_storage=delta_storage,
             output_table_provider=delta_table_provider,
             previous_table_provider=previous_table_provider,
@@ -88,15 +91,13 @@
 
     # if the user passes in a df directly, write directly to storage so we can skip finding/parsing later
     if input_documents is not None:
-        output_table_provider = ParquetTableProvider(output_storage)
         await output_table_provider.write_dataframe("documents", input_documents)
         pipeline.remove("load_input_documents")
 
     context = create_run_context(
         input_storage=input_storage,
-        input_table_provider=input_table_provider,
         output_storage=output_storage,
-        output_table_provider=ParquetTableProvider(storage=output_storage),
+        output_table_provider=output_table_provider,
         cache=cache,
         callbacks=callbacks,
         state=state,
diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py
index 207e9561a0..b85459bf74 100644
--- a/packages/graphrag/graphrag/index/run/utils.py
+++ b/packages/graphrag/graphrag/index/run/utils.py
@@ -8,6 +8,8 @@
 from graphrag_storage import Storage, create_storage
 from graphrag_storage.memory_storage import MemoryStorage
 from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider
+from graphrag_storage.tables.table_provider import TableProvider
+from graphrag_storage.tables.table_provider_factory import create_table_provider
 
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -20,10 +22,9 @@
 
 def create_run_context(
     input_storage: Storage | None = None,
-    input_table_provider: ParquetTableProvider | None = None,
     output_storage: Storage | None = None,
-    output_table_provider: ParquetTableProvider | None = None,
-    previous_table_provider: ParquetTableProvider | None = None,
+    output_table_provider: TableProvider | None = None,
+    previous_table_provider: TableProvider | None = None,
     cache: Cache | None = None,
     callbacks: WorkflowCallbacks | None = None,
     stats: PipelineRunStats | None = None,
@@ -34,8 +35,6 @@
     output_storage = output_storage or MemoryStorage()
     return PipelineRunContext(
         input_storage=input_storage,
-        input_table_provider=input_table_provider
-        or ParquetTableProvider(storage=input_storage),
        output_storage=output_storage,
         output_table_provider=output_table_provider
         or ParquetTableProvider(storage=output_storage),
@@ -59,7 +58,7 @@ def create_callback_chain(
 
 def get_update_table_providers(
     config: GraphRagConfig, timestamp: str
-) -> tuple[ParquetTableProvider, ParquetTableProvider, ParquetTableProvider]:
+) -> tuple[TableProvider, TableProvider, TableProvider]:
     """Get table providers for the update index run."""
     output_storage = create_storage(config.output_storage)
     update_storage = create_storage(config.update_output_storage)
@@ -67,8 +66,10 @@
     delta_storage = timestamped_storage.child("delta")
     previous_storage = timestamped_storage.child("previous")
 
-    output_table_provider = ParquetTableProvider(output_storage)
-    previous_table_provider = ParquetTableProvider(previous_storage)
-    delta_table_provider = ParquetTableProvider(delta_storage)
+    output_table_provider = create_table_provider(config.table_provider, output_storage)
+    previous_table_provider = create_table_provider(
+        config.table_provider, previous_storage
+    )
+    delta_table_provider = create_table_provider(config.table_provider, delta_storage)
 
     return output_table_provider, previous_table_provider, delta_table_provider
diff --git a/packages/graphrag/graphrag/index/typing/context.py b/packages/graphrag/graphrag/index/typing/context.py
index f606218dd2..277e41f090 100644
--- a/packages/graphrag/graphrag/index/typing/context.py
+++ b/packages/graphrag/graphrag/index/typing/context.py
@@ -20,8 +20,6 @@ class PipelineRunContext:
     stats: PipelineRunStats
     input_storage: Storage
     "Storage for reading input documents."
-    input_table_provider: TableProvider
-    "Table provider for reading input tables."
     output_storage: Storage
     "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
     output_table_provider: TableProvider
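
Usage sketch (not part of this patch): the snippet below illustrates how a custom table provider could plug into the new factory API introduced above. CsvTableProvider, its constructor, and the assumption that the Factory base class forwards the dumped config dict (plus the optional "storage" entry) to the registered initializer as keyword arguments are hypothetical; only register_table_provider, create_table_provider, TableProviderConfig, and the TableType.Parquet default come from the code in this diff.

# Hypothetical usage sketch; not part of this change. Assumes the Factory base
# class forwards the dumped TableProviderConfig (plus the optional "storage"
# entry) to the registered initializer as keyword arguments.
from graphrag_storage.tables.table_provider_config import TableProviderConfig
from graphrag_storage.tables.table_provider_factory import (
    create_table_provider,
    register_table_provider,
)


class CsvTableProvider:
    """Illustrative provider; a real implementation would subclass TableProvider."""

    def __init__(self, storage=None, **kwargs):
        # "type" (and any extra TableProviderConfig fields, since the model sets
        # extra="allow") arrives here via **kwargs; "storage" is populated when
        # the caller passes one to create_table_provider, as the pipeline does.
        self.storage = storage
        self.options = kwargs


# Register once at startup; "csv" then becomes a valid TableProviderConfig.type.
register_table_provider("csv", CsvTableProvider)

# The pipeline would pass storage=create_storage(...) for file-based providers.
provider = create_table_provider(TableProviderConfig(type="csv"))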