From 0e19173df3cdcbd8201f99bc0d3a09b961a1932d Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Wed, 22 Oct 2025 10:19:10 -0700 Subject: [PATCH 01/17] Add graphrag-storage. --- docs/config/yaml.md | 6 +- packages/graphrag-storage/README.md | 14 +++ .../graphrag_storage/__init__.py | 15 +++ .../graphrag_storage/azure_blob_storage.py} | 83 ++++++++-------- .../graphrag_storage/azure_cosmos_storage.py} | 50 ++++++---- .../graphrag_storage/file_storage.py} | 53 +++++----- .../graphrag_storage/memory_storage.py} | 16 +-- .../graphrag_storage/py.typed | 0 .../graphrag_storage/storage.py} | 14 ++- .../graphrag_storage/storage_config.py | 41 ++++++++ .../graphrag_storage/storage_factory.py | 68 +++++++++++++ packages/graphrag-storage/pyproject.toml | 47 +++++++++ packages/graphrag/graphrag/cache/factory.py | 12 +-- .../graphrag/cache/json_pipeline_cache.py | 7 +- packages/graphrag/graphrag/config/defaults.py | 5 +- .../graphrag/graphrag/config/init_content.py | 4 +- .../config/models/graph_rag_config.py | 9 +- .../graphrag/config/models/input_config.py | 2 +- .../graphrag/config/models/storage_config.py | 52 ---------- .../graphrag/index/input/input_reader.py | 5 +- .../index/operations/snapshot_graphml.py | 5 +- .../graphrag/index/run/run_pipeline.py | 6 +- packages/graphrag/graphrag/index/run/utils.py | 19 ++-- .../graphrag/graphrag/index/typing/context.py | 8 +- .../index/update/incremental_index.py | 14 ++- .../index/workflows/load_update_documents.py | 4 +- .../index/workflows/update_communities.py | 9 +- .../workflows/update_community_reports.py | 8 +- .../index/workflows/update_covariates.py | 8 +- .../update_entities_relationships.py | 8 +- .../index/workflows/update_text_units.py | 8 +- .../graphrag/graphrag/storage/__init__.py | 4 + packages/graphrag/graphrag/storage/factory.py | 33 ------- packages/graphrag/graphrag/utils/api.py | 13 +-- packages/graphrag/graphrag/utils/storage.py | 11 +-- packages/graphrag/pyproject.toml | 1 + pyproject.toml | 4 + 
...peline_storage.py => test_blob_storage.py} | 10 +- .../storage/test_cosmosdb_storage.py | 12 +-- tests/integration/storage/test_factory.py | 97 ++++++++----------- ...peline_storage.py => test_file_storage.py} | 12 +-- tests/smoke/test_fixtures.py | 4 +- tests/unit/config/utils.py | 2 +- .../cache/test_file_pipeline_cache.py | 6 +- tests/unit/indexing/input/test_csv_loader.py | 2 +- tests/unit/indexing/input/test_json_loader.py | 2 +- tests/unit/indexing/input/test_txt_loader.py | 2 +- tests/unit/load_config/fixtures/config.yaml | 10 ++ uv.lock | 24 +++++ 49 files changed, 502 insertions(+), 347 deletions(-) create mode 100644 packages/graphrag-storage/README.md create mode 100644 packages/graphrag-storage/graphrag_storage/__init__.py rename packages/{graphrag/graphrag/storage/blob_pipeline_storage.py => graphrag-storage/graphrag_storage/azure_blob_storage.py} (81%) rename packages/{graphrag/graphrag/storage/cosmosdb_pipeline_storage.py => graphrag-storage/graphrag_storage/azure_cosmos_storage.py} (91%) rename packages/{graphrag/graphrag/storage/file_pipeline_storage.py => graphrag-storage/graphrag_storage/file_storage.py} (72%) rename packages/{graphrag/graphrag/storage/memory_pipeline_storage.py => graphrag-storage/graphrag_storage/memory_storage.py} (82%) create mode 100644 packages/graphrag-storage/graphrag_storage/py.typed rename packages/{graphrag/graphrag/storage/pipeline_storage.py => graphrag-storage/graphrag_storage/storage.py} (88%) create mode 100644 packages/graphrag-storage/graphrag_storage/storage_config.py create mode 100644 packages/graphrag-storage/graphrag_storage/storage_factory.py create mode 100644 packages/graphrag-storage/pyproject.toml delete mode 100644 packages/graphrag/graphrag/config/models/storage_config.py delete mode 100644 packages/graphrag/graphrag/storage/factory.py rename tests/integration/storage/{test_blob_pipeline_storage.py => test_blob_storage.py} (94%) rename tests/integration/storage/{test_file_pipeline_storage.py => 
test_file_storage.py} (86%) create mode 100644 tests/unit/load_config/fixtures/config.yaml diff --git a/docs/config/yaml.md b/docs/config/yaml.md index ace57e3b1c..941e938a6d 100644 --- a/docs/config/yaml.md +++ b/docs/config/yaml.md @@ -81,7 +81,7 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th #### Fields - `storage` **StorageConfig** - - `type` **file|blob|cosmosdb** - The storage type to use. Default=`file` + - `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -115,7 +115,7 @@ This section controls the storage mechanism used by the pipeline used for export #### Fields -- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file` +- `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -128,7 +128,7 @@ The section defines a secondary storage location for running incremental indexin #### Fields -- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file` +- `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. 
diff --git a/packages/graphrag-storage/README.md b/packages/graphrag-storage/README.md new file mode 100644 index 0000000000..37b17a0848 --- /dev/null +++ b/packages/graphrag-storage/README.md @@ -0,0 +1,14 @@ +# GraphRAG Storage + +```python +from graphrag_storage import StorageConfig, create_storage +from graphrag_storage.file_storage import FileStorage + +storage = create_storage( + StorageConfig( + type="FileStorage", # or FileStorage.__name__ + base_dir="output" + ) +) + +``` \ No newline at end of file diff --git a/packages/graphrag-storage/graphrag_storage/__init__.py b/packages/graphrag-storage/graphrag_storage/__init__.py new file mode 100644 index 0000000000..0684dfb889 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The GraphRAG Storage package.""" + +from graphrag_storage.storage import Storage +from graphrag_storage.storage_config import StorageConfig +from graphrag_storage.storage_factory import create_storage, register_storage + +__all__ = [ + "Storage", + "StorageConfig", + "create_storage", + "register_storage", +] diff --git a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py similarity index 81% rename from packages/graphrag/graphrag/storage/blob_pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/azure_blob_storage.py index 1435cb387d..9028259bf8 100644 --- a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -"""Azure Blob Storage implementation of PipelineStorage.""" +"""Azure Blob Storage implementation of Storage.""" import logging import re @@ -12,15 +12,15 @@ from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient -from graphrag.storage.pipeline_storage import ( - PipelineStorage, +from graphrag_storage.storage import ( + Storage, get_timestamp_formatted_with_local_tz, ) logger = logging.getLogger(__name__) -class BlobPipelineStorage(PipelineStorage): +class AzureBlobStorage(Storage): """The Blob-Storage implementation.""" _connection_string: str | None @@ -28,20 +28,36 @@ class BlobPipelineStorage(PipelineStorage): _base_dir: str | None _encoding: str _storage_account_blob_url: str | None + _blob_service_client: BlobServiceClient + _storage_account_name: str | None - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, + base_dir: str | None = None, + connection_string: str | None = None, + storage_account_blob_url: str | None = None, + container_name: str | None = None, + encoding: str = "utf-8", + **kwargs: Any, + ) -> None: """Create a new BlobStorage instance.""" - connection_string = kwargs.get("connection_string") - storage_account_blob_url = kwargs.get("storage_account_blob_url") - base_dir = kwargs.get("base_dir") - container_name = kwargs["container_name"] - if container_name is None: - msg = "No container name provided for blob storage." - raise ValueError(msg) if connection_string is None and storage_account_blob_url is None: - msg = "No storage account blob url provided for blob storage." + msg = "AzureBlobStorage requires either a connection_string or storage_account_blob_url to be specified." + logger.error(msg) + raise ValueError(msg) + + if connection_string is not None and storage_account_blob_url is not None: + msg = "AzureBlobStorage requires only one of connection_string or storage_account_blob_url to be specified, not both." 
+ logger.error(msg) raise ValueError(msg) + if container_name is None: + msg = "AzureBlobStorage requires a container_name to be specified." + logger.error(msg) + raise ValueError(msg) + + _validate_blob_container_name(container_name) + logger.info( "Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir ) @@ -49,16 +65,12 @@ def __init__(self, **kwargs: Any) -> None: self._blob_service_client = BlobServiceClient.from_connection_string( connection_string ) - else: - if storage_account_blob_url is None: - msg = "Either connection_string or storage_account_blob_url must be provided." - raise ValueError(msg) - + elif storage_account_blob_url: self._blob_service_client = BlobServiceClient( account_url=storage_account_blob_url, credential=DefaultAzureCredential(), ) - self._encoding = kwargs.get("encoding", "utf-8") + self._encoding = encoding self._container_name = container_name self._connection_string = connection_string self._base_dir = base_dir @@ -208,12 +220,12 @@ async def delete(self, key: str) -> None: async def clear(self) -> None: """Clear the cache.""" - def child(self, name: str | None) -> "PipelineStorage": + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" if name is None: return self path = str(Path(self._base_dir) / name) if self._base_dir else name - return BlobPipelineStorage( + return AzureBlobStorage( connection_string=self._connection_string, container_name=self._container_name, encoding=self._encoding, @@ -245,7 +257,7 @@ async def get_creation_date(self, key: str) -> str: return "" -def validate_blob_container_name(container_name: str): +def _validate_blob_container_name(container_name: str) -> None: """ Check if the provided blob container name is valid based on Azure rules. 
@@ -267,32 +279,25 @@ def validate_blob_container_name(container_name: str): """ # Check the length of the name if len(container_name) < 3 or len(container_name) > 63: - return ValueError( - f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long." - ) + msg = f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long." + raise ValueError(msg) # Check if the name starts with a letter or number if not container_name[0].isalnum(): - return ValueError( - f"Container name must start with a letter or number. Starting character was {container_name[0]}." - ) + msg = f"Container name must start with a letter or number. Starting character was {container_name[0]}." + raise ValueError(msg) # Check for valid characters (letters, numbers, hyphen) and lowercase letters if not re.match(r"^[a-z0-9-]+$", container_name): - return ValueError( - f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}." - ) + msg = f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}." + raise ValueError(msg) # Check for consecutive hyphens if "--" in container_name: - return ValueError( - f"Container name cannot contain consecutive hyphens. Name provided was {container_name}." - ) + msg = f"Container name cannot contain consecutive hyphens. Name provided was {container_name}." + raise ValueError(msg) # Check for hyphens at the end of the name if container_name[-1] == "-": - return ValueError( - f"Container name cannot end with a hyphen. Name provided was {container_name}." - ) - - return True + msg = f"Container name cannot end with a hyphen. Name provided was {container_name}." 
+ raise ValueError(msg) diff --git a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py similarity index 91% rename from packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index a12da0ee5f..4e4e034eb7 100644 --- a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -16,17 +16,17 @@ from azure.cosmos.exceptions import CosmosResourceNotFoundError from azure.cosmos.partition_key import PartitionKey from azure.identity import DefaultAzureCredential - from graphrag.logger.progress import Progress -from graphrag.storage.pipeline_storage import ( - PipelineStorage, + +from graphrag_storage.storage import ( + Storage, get_timestamp_formatted_with_local_tz, ) logger = logging.getLogger(__name__) -class CosmosDBPipelineStorage(PipelineStorage): +class AzureCosmosStorage(Storage): """The CosmosDB-Storage Implementation.""" _cosmos_client: CosmosClient @@ -39,28 +39,40 @@ class CosmosDBPipelineStorage(PipelineStorage): _encoding: str _no_id_prefixes: list[str] - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, + base_dir: str | None = None, + container_name: str | None = None, + connection_string: str | None = None, + cosmosdb_account_url: str | None = None, + **kwargs: Any, + ) -> None: """Create a CosmosDB storage instance.""" logger.info("Creating cosmosdb storage") - cosmosdb_account_url = kwargs.get("cosmosdb_account_url") - connection_string = kwargs.get("connection_string") - database_name = kwargs["base_dir"] - container_name = kwargs["container_name"] - if not database_name: - msg = "No base_dir provided for database name" + database_name = base_dir + if database_name is None: + msg = "CosmosDB Storage requires a base_dir to be specified. This is used as the database name." 
+ logger.error(msg) raise ValueError(msg) + if connection_string is None and cosmosdb_account_url is None: - msg = "connection_string or cosmosdb_account_url is required." + msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified." + logger.error(msg) + raise ValueError(msg) + + if connection_string is not None and cosmosdb_account_url is not None: + msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified, not both." + logger.error(msg) + raise ValueError(msg) + + if container_name is None: + msg = "CosmosDB Storage requires a container_name to be specified." + logger.error(msg) raise ValueError(msg) if connection_string: self._cosmos_client = CosmosClient.from_connection_string(connection_string) - else: - if cosmosdb_account_url is None: - msg = ( - "Either connection_string or cosmosdb_account_url must be provided." - ) - raise ValueError(msg) + elif cosmosdb_account_url: self._cosmos_client = CosmosClient( url=cosmosdb_account_url, credential=DefaultAzureCredential(), @@ -307,7 +319,7 @@ def keys(self) -> list[str]: msg = "CosmosDB storage does yet not support listing keys." raise NotImplementedError(msg) - def child(self, name: str | None) -> PipelineStorage: + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" return self diff --git a/packages/graphrag/graphrag/storage/file_pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/file_storage.py similarity index 72% rename from packages/graphrag/graphrag/storage/file_pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/file_storage.py index 52402c8bd6..61cb922ec6 100644 --- a/packages/graphrag/graphrag/storage/file_pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/file_storage.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -"""File-based Storage implementation of PipelineStorage.""" +"""File-based Storage implementation of Storage.""" import logging import os @@ -16,26 +16,33 @@ from aiofiles.os import remove from aiofiles.ospath import exists -from graphrag.storage.pipeline_storage import ( - PipelineStorage, +from graphrag_storage.storage import ( + Storage, get_timestamp_formatted_with_local_tz, ) logger = logging.getLogger(__name__) -class FilePipelineStorage(PipelineStorage): +class FileStorage(Storage): """File storage class definition.""" - _base_dir: str + _base_dir: Path _encoding: str - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, base_dir: str | None = "", encoding: str = "utf-8", **kwargs: Any + ) -> None: """Create a file based storage.""" - self._base_dir = kwargs.get("base_dir", "") - self._encoding = kwargs.get("encoding", "utf-8") + if base_dir is None: + msg = "FileStorage requires a base_dir to be specified." + logger.error(msg) + raise ValueError(msg) + + self._base_dir = Path(base_dir).resolve() + self._encoding = encoding logger.info("Creating file storage at [%s]", self._base_dir) - Path(self._base_dir).mkdir(parents=True, exist_ok=True) + self._base_dir.mkdir(parents=True, exist_ok=True) def find( self, @@ -45,7 +52,7 @@ def find( logger.info( "Search [%s] for files matching [%s]", self._base_dir, file_pattern.pattern ) - all_files = list(Path(self._base_dir).rglob("**/*")) + all_files = list(self._base_dir.rglob("**/*")) logger.debug("All files and folders: %s", [file.name for file in all_files]) num_loaded = 0 num_total = len(all_files) @@ -53,7 +60,7 @@ def find( for file in all_files: match = file_pattern.search(f"{file}") if match: - filename = f"{file}".replace(str(Path(self._base_dir)), "", 1) + filename = f"{file}".replace(str(self._base_dir), "", 1) if filename.startswith(os.sep): filename = filename[1:] yield filename @@ -71,7 +78,7 @@ async def get( self, key: str, as_bytes: bool | None = False, 
encoding: str | None = None ) -> Any: """Get method definition.""" - file_path = join_path(self._base_dir, key) + file_path = _join_path(self._base_dir, key) if await self.has(key): return await self._read_file(file_path, as_bytes, encoding) @@ -101,7 +108,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: write_type = "wb" if is_bytes else "w" encoding = None if is_bytes else encoding or self._encoding async with aiofiles.open( - join_path(self._base_dir, key), + _join_path(self._base_dir, key), cast("Any", write_type), encoding=encoding, ) as f: @@ -109,35 +116,35 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: async def has(self, key: str) -> bool: """Has method definition.""" - return await exists(join_path(self._base_dir, key)) + return await exists(_join_path(self._base_dir, key)) async def delete(self, key: str) -> None: """Delete method definition.""" if await self.has(key): - await remove(join_path(self._base_dir, key)) + await remove(_join_path(self._base_dir, key)) async def clear(self) -> None: """Clear method definition.""" - for file in Path(self._base_dir).glob("*"): + for file in self._base_dir.glob("*"): if file.is_dir(): shutil.rmtree(file) else: file.unlink() - def child(self, name: str | None) -> "PipelineStorage": + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" if name is None: return self - child_path = str(Path(self._base_dir) / Path(name)) - return FilePipelineStorage(base_dir=child_path, encoding=self._encoding) + child_path = str(self._base_dir / name) + return FileStorage(base_dir=child_path, encoding=self._encoding) def keys(self) -> list[str]: """Return the keys in the storage.""" - return [item.name for item in Path(self._base_dir).iterdir() if item.is_file()] + return [item.name for item in self._base_dir.iterdir() if item.is_file()] async def get_creation_date(self, key: str) -> str: """Get the creation date of a file.""" - 
file_path = Path(join_path(self._base_dir, key)) + file_path = _join_path(self._base_dir, key) creation_timestamp = file_path.stat().st_ctime creation_time_utc = datetime.fromtimestamp(creation_timestamp, tz=timezone.utc) @@ -145,6 +152,6 @@ async def get_creation_date(self, key: str) -> str: return get_timestamp_formatted_with_local_tz(creation_time_utc) -def join_path(file_path: str, file_name: str) -> Path: +def _join_path(file_path: Path, file_name: str) -> Path: """Join a path and a file. Independent of the OS.""" - return Path(file_path) / Path(file_name).parent / Path(file_name).name + return (file_path / Path(file_name).parent / Path(file_name).name).resolve() diff --git a/packages/graphrag/graphrag/storage/memory_pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/memory_storage.py similarity index 82% rename from packages/graphrag/graphrag/storage/memory_pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/memory_storage.py index 3567e3d1e3..7908d98a35 100644 --- a/packages/graphrag/graphrag/storage/memory_pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/memory_storage.py @@ -1,24 +1,24 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -"""A module containing 'InMemoryStorage' model.""" +"""In-memory storage implementation.""" from typing import TYPE_CHECKING, Any -from graphrag.storage.file_pipeline_storage import FilePipelineStorage +from graphrag_storage.file_storage import FileStorage if TYPE_CHECKING: - from graphrag.storage.pipeline_storage import PipelineStorage + from graphrag_storage.storage import Storage -class MemoryPipelineStorage(FilePipelineStorage): +class MemoryStorage(FileStorage): """In memory storage class definition.""" _storage: dict[str, Any] - def __init__(self): + def __init__(self, **kwargs: Any) -> None: """Init method definition.""" - super().__init__() + super().__init__(**kwargs) self._storage = {} async def get( @@ -69,9 +69,9 @@ async def clear(self) -> None: """Clear the storage.""" self._storage.clear() - def child(self, name: str | None) -> "PipelineStorage": + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" - return MemoryPipelineStorage() + return MemoryStorage() def keys(self) -> list[str]: """Return the keys in the storage.""" diff --git a/packages/graphrag-storage/graphrag_storage/py.typed b/packages/graphrag-storage/graphrag_storage/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/graphrag/graphrag/storage/pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/storage.py similarity index 88% rename from packages/graphrag/graphrag/storage/pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/storage.py index 5c79921736..c9fe400331 100644 --- a/packages/graphrag/graphrag/storage/pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/storage.py @@ -1,17 +1,21 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -"""A module containing 'PipelineStorage' model.""" +"""Abstract base class for storage.""" import re -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from collections.abc import Iterator from datetime import datetime from typing import Any -class PipelineStorage(metaclass=ABCMeta): - """Provide a storage interface for the pipeline. This is where the pipeline will store its output data.""" +class Storage(ABC): + """Provide a storage interface.""" + + @abstractmethod + def __init__(self, **kwargs: Any) -> None: + """Create a storage instance.""" @abstractmethod def find( @@ -69,7 +73,7 @@ async def clear(self) -> None: """Clear the storage.""" @abstractmethod - def child(self, name: str | None) -> "PipelineStorage": + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" @abstractmethod diff --git a/packages/graphrag-storage/graphrag_storage/storage_config.py b/packages/graphrag-storage/graphrag_storage/storage_config.py new file mode 100644 index 0000000000..0a8cf76893 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/storage_config.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Storage configuration model.""" + +from pydantic import BaseModel, ConfigDict, Field + + +class StorageConfig(BaseModel): + """The default configuration section for storage.""" + + model_config = ConfigDict(extra="allow") + """Allow extra fields to support custom storage implementations.""" + + type: str = Field( + description="The storage type to use.", + default="FileStorage", + ) + + base_dir: str | None = Field( + description="The base directory for the output.", + default=None, + ) + + connection_string: str | None = Field( + description="The storage connection string to use.", + default=None, + ) + + container_name: str | None = Field( + description="The storage container name to use.", + default=None, + ) + storage_account_blob_url: str | None = Field( + description="The storage account blob url to use.", + default=None, + ) + cosmosdb_account_url: str | None = Field( + description="The cosmosdb account url to use.", + default=None, + ) diff --git a/packages/graphrag-storage/graphrag_storage/storage_factory.py b/packages/graphrag-storage/graphrag_storage/storage_factory.py new file mode 100644 index 0000000000..bbaec02258 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/storage_factory.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + + +"""Storage factory implementation.""" + +from collections.abc import Callable + +from graphrag_common.factory import Factory + +from graphrag_storage.azure_blob_storage import AzureBlobStorage +from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage +from graphrag_storage.file_storage import FileStorage +from graphrag_storage.memory_storage import MemoryStorage +from graphrag_storage.storage import Storage +from graphrag_storage.storage_config import StorageConfig + + +class _StorageFactory(Factory[Storage]): + """A factory class for storage implementations. 
+ + Includes a method for users to register a custom storage implementation. + + Configuration arguments are passed to each storage implementation as kwargs + for individual enforcement of required/optional arguments. + """ + + +storage_factory = _StorageFactory() +storage_factory.register(FileStorage.__name__, FileStorage) +storage_factory.register(MemoryStorage.__name__, MemoryStorage) +storage_factory.register(AzureBlobStorage.__name__, AzureBlobStorage) +storage_factory.register(AzureCosmosStorage.__name__, AzureCosmosStorage) + + +def register_storage(storage: str, storage_initializer: Callable[..., Storage]) -> None: + """Register a custom storage implementation. + + Args + ---- + - storage: str + The storage id to register. + - storage_initializer: Callable[..., Storage] + The storage initializer to register. + """ + storage_factory.register(storage, storage_initializer) + + +def create_storage(config: StorageConfig) -> Storage: + """Create a storage implementation based on the given configuration. + + Args + ---- + - config: StorageConfig + The storage configuration to use. + + Returns + ------- + Storage + The created storage implementation. + """ + storage_strategy = config.type + + if storage_strategy not in storage_factory: + msg = f"StorageConfig.type '{storage_strategy}' is not registered in the StorageFactory. Registered types: {', '.join(storage_factory.keys())}" + raise ValueError(msg) + + return storage_factory.create(config.type, config.model_dump()) diff --git a/packages/graphrag-storage/pyproject.toml b/packages/graphrag-storage/pyproject.toml new file mode 100644 index 0000000000..fb21d4ad2f --- /dev/null +++ b/packages/graphrag-storage/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "graphrag-storage" +version = "2.7.0" +description = "GraphRAG storage package." 
+authors = [ + {name = "Alonso Guevara Fernández", email = "alonsog@microsoft.com"}, + {name = "Andrés Morales Esquivel", email = "andresmor@microsoft.com"}, + {name = "Chris Trevino", email = "chtrevin@microsoft.com"}, + {name = "David Tittsworth", email = "datittsw@microsoft.com"}, + {name = "Dayenne de Souza", email = "ddesouza@microsoft.com"}, + {name = "Derek Worthen", email = "deworthe@microsoft.com"}, + {name = "Gaudy Blanco Meneses", email = "gaudyb@microsoft.com"}, + {name = "Ha Trinh", email = "trinhha@microsoft.com"}, + {name = "Jonathan Larson", email = "jolarso@microsoft.com"}, + {name = "Josh Bradley", email = "joshbradley@microsoft.com"}, + {name = "Kate Lytvynets", email = "kalytv@microsoft.com"}, + {name = "Kenny Zhang", email = "zhangken@microsoft.com"}, + {name = "Mónica Carvajal"}, + {name = "Nathan Evans", email = "naevans@microsoft.com"}, + {name = "Rodrigo Racanicci", email = "rracanicci@microsoft.com"}, + {name = "Sarah Smith", email = "smithsarah@microsoft.com"}, +] +license = "MIT" +readme = "README.md" +license-files = ["LICENSE"] +requires-python = ">=3.10,<3.13" +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "azure-cosmos>=4.9.0", + "azure-identity>=1.19.0", + "azure-storage-blob>=12.24.0", + "graphrag-common==2.7.0", + "pydantic>=2.10.3", +] + +[project.urls] +Source = "https://github.com/microsoft/graphrag" + +[build-system] +requires = ["hatchling>=1.27.0,<2.0.0"] +build-backend = "hatchling.build" + diff --git a/packages/graphrag/graphrag/cache/factory.py b/packages/graphrag/graphrag/cache/factory.py index 971c22c6d5..ccbf1e200f 100644 --- a/packages/graphrag/graphrag/cache/factory.py +++ b/packages/graphrag/graphrag/cache/factory.py @@ -6,15 +6,15 @@ from __future__ import annotations from graphrag_common.factory import Factory +from graphrag_storage.azure_blob_storage 
import AzureBlobStorage +from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage +from graphrag_storage.file_storage import FileStorage from graphrag.cache.json_pipeline_cache import JsonPipelineCache from graphrag.cache.memory_pipeline_cache import InMemoryCache from graphrag.cache.noop_pipeline_cache import NoopPipelineCache from graphrag.cache.pipeline_cache import PipelineCache from graphrag.config.enums import CacheType -from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage -from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage -from graphrag.storage.file_pipeline_storage import FilePipelineStorage class CacheFactory(Factory[PipelineCache]): @@ -30,19 +30,19 @@ class CacheFactory(Factory[PipelineCache]): # --- register built-in cache implementations --- def create_file_cache(**kwargs) -> PipelineCache: """Create a file-based cache implementation.""" - storage = FilePipelineStorage(**kwargs) + storage = FileStorage(**kwargs) return JsonPipelineCache(storage) def create_blob_cache(**kwargs) -> PipelineCache: """Create a blob storage-based cache implementation.""" - storage = BlobPipelineStorage(**kwargs) + storage = AzureBlobStorage(**kwargs) return JsonPipelineCache(storage) def create_cosmosdb_cache(**kwargs) -> PipelineCache: """Create a CosmosDB-based cache implementation.""" - storage = CosmosDBPipelineStorage(**kwargs) + storage = AzureCosmosStorage(**kwargs) return JsonPipelineCache(storage) diff --git a/packages/graphrag/graphrag/cache/json_pipeline_cache.py b/packages/graphrag/graphrag/cache/json_pipeline_cache.py index 84cd180c52..22b438936e 100644 --- a/packages/graphrag/graphrag/cache/json_pipeline_cache.py +++ b/packages/graphrag/graphrag/cache/json_pipeline_cache.py @@ -6,17 +6,18 @@ import json from typing import Any +from graphrag_storage import Storage + from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.storage.pipeline_storage import PipelineStorage class 
JsonPipelineCache(PipelineCache): """File pipeline cache class definition.""" - _storage: PipelineStorage + _storage: Storage _encoding: str - def __init__(self, storage: PipelineStorage, encoding="utf-8"): + def __init__(self, storage: Storage, encoding="utf-8"): """Init method definition.""" self._storage = storage self._encoding = encoding diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index 88449a6050..9d988ad928 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -7,6 +7,8 @@ from pathlib import Path from typing import ClassVar +from graphrag_storage.file_storage import FileStorage + from graphrag.config.embeddings import default_embeddings from graphrag.config.enums import ( AsyncType, @@ -17,7 +19,6 @@ ModelType, NounPhraseExtractorType, ReportingType, - StorageType, VectorStoreType, ) from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import ( @@ -230,7 +231,7 @@ class GlobalSearchDefaults: class StorageDefaults: """Default values for storage.""" - type: ClassVar[StorageType] = StorageType.file + type: str = FileStorage.__name__ base_dir: str | None = None connection_string: None = None container_name: None = None diff --git a/packages/graphrag/graphrag/config/init_content.py b/packages/graphrag/graphrag/config/init_content.py index 1cbccf74df..1cb70ddd16 100644 --- a/packages/graphrag/graphrag/config/init_content.py +++ b/packages/graphrag/graphrag/config/init_content.py @@ -50,7 +50,7 @@ input: storage: - type: {graphrag_config_defaults.input.storage.type.value} # or blob + type: {graphrag_config_defaults.input.storage.type} # or AzureBlobStorage, AzureCosmosStorage base_dir: "{graphrag_config_defaults.input.storage.base_dir}" file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json] @@ -63,7 +63,7 @@ ## connection_string and container_name must be provided output: - type: 
{graphrag_config_defaults.output.type.value} # [file, blob, cosmosdb] + type: {graphrag_config_defaults.output.type} # or AzureBlobStorage, AzureCosmosStorage base_dir: "{graphrag_config_defaults.output.base_dir}" cache: diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index 15d02eaf3a..e2bdf81f72 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -6,6 +6,8 @@ from pathlib import Path from devtools import pformat +from graphrag_storage import StorageConfig +from graphrag_storage.file_storage import FileStorage from pydantic import BaseModel, Field, model_validator import graphrag.config.defaults as defs @@ -29,7 +31,6 @@ from graphrag.config.models.prune_graph_config import PruneGraphConfig from graphrag.config.models.reporting_config import ReportingConfig from graphrag.config.models.snapshots_config import SnapshotsConfig -from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) @@ -135,7 +136,7 @@ def _validate_input_pattern(self) -> None: def _validate_input_base_dir(self) -> None: """Validate the input base directory.""" - if self.input.storage.type == defs.StorageType.file: + if self.input.storage.type == FileStorage.__name__: if not self.input.storage.base_dir: msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration." raise ValueError(msg) @@ -159,7 +160,7 @@ def _validate_input_base_dir(self) -> None: def _validate_output_base_dir(self) -> None: """Validate the output base directory.""" - if self.output.type == defs.StorageType.file: + if self.output.type == FileStorage.__name__: if not self.output.base_dir: msg = "output base directory is required for file output. 
Please rerun `graphrag init` and set the output configuration." raise ValueError(msg) @@ -175,7 +176,7 @@ def _validate_output_base_dir(self) -> None: def _validate_update_index_output_base_dir(self) -> None: """Validate the update index output base directory.""" - if self.update_index_output.type == defs.StorageType.file: + if self.update_index_output.type == FileStorage.__name__: if not self.update_index_output.base_dir: msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration." raise ValueError(msg) diff --git a/packages/graphrag/graphrag/config/models/input_config.py b/packages/graphrag/graphrag/config/models/input_config.py index bc34d9402d..c3c30f6302 100644 --- a/packages/graphrag/graphrag/config/models/input_config.py +++ b/packages/graphrag/graphrag/config/models/input_config.py @@ -3,12 +3,12 @@ """Parameterization settings for the default configuration.""" +from graphrag_storage import StorageConfig from pydantic import BaseModel, Field import graphrag.config.defaults as defs from graphrag.config.defaults import graphrag_config_defaults from graphrag.config.enums import InputFileType -from graphrag.config.models.storage_config import StorageConfig class InputConfig(BaseModel): diff --git a/packages/graphrag/graphrag/config/models/storage_config.py b/packages/graphrag/graphrag/config/models/storage_config.py deleted file mode 100644 index 7491454c0a..0000000000 --- a/packages/graphrag/graphrag/config/models/storage_config.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""Parameterization settings for the default configuration.""" - -from pathlib import Path - -from pydantic import BaseModel, Field, field_validator - -from graphrag.config.defaults import graphrag_config_defaults -from graphrag.config.enums import StorageType - - -class StorageConfig(BaseModel): - """The default configuration section for storage.""" - - type: StorageType | str = Field( - description="The storage type to use.", - default=graphrag_config_defaults.storage.type, - ) - base_dir: str | None = Field( - description="The base directory for the output.", - default=graphrag_config_defaults.storage.base_dir, - ) - - # Validate the base dir for multiple OS (use Path) - # if not using a cloud storage type. - @field_validator("base_dir", mode="before") - @classmethod - def validate_base_dir(cls, value, info): - """Ensure that base_dir is a valid filesystem path when using local storage.""" - # info.data contains other field values, including 'type' - if info.data.get("type") != StorageType.file: - return value - return str(Path(value)) - - connection_string: str | None = Field( - description="The storage connection string to use.", - default=graphrag_config_defaults.storage.connection_string, - ) - container_name: str | None = Field( - description="The storage container name to use.", - default=graphrag_config_defaults.storage.container_name, - ) - storage_account_blob_url: str | None = Field( - description="The storage account blob url to use.", - default=graphrag_config_defaults.storage.storage_account_blob_url, - ) - cosmosdb_account_url: str | None = Field( - description="The cosmosdb account url to use.", - default=graphrag_config_defaults.storage.cosmosdb_account_url, - ) diff --git a/packages/graphrag/graphrag/index/input/input_reader.py b/packages/graphrag/graphrag/index/input/input_reader.py index ed0add9f97..98a713e509 100644 --- a/packages/graphrag/graphrag/index/input/input_reader.py +++ 
b/packages/graphrag/graphrag/index/input/input_reader.py @@ -13,8 +13,9 @@ import pandas as pd if TYPE_CHECKING: + from graphrag_storage import Storage + from graphrag.config.models.input_config import InputConfig - from graphrag.storage.pipeline_storage import PipelineStorage logger = logging.getLogger(__name__) @@ -22,7 +23,7 @@ class InputReader(metaclass=ABCMeta): """Provide a cache interface for the pipeline.""" - def __init__(self, storage: PipelineStorage, config: InputConfig, **kwargs): + def __init__(self, storage: Storage, config: InputConfig, **kwargs): self._storage = storage self._config = config diff --git a/packages/graphrag/graphrag/index/operations/snapshot_graphml.py b/packages/graphrag/graphrag/index/operations/snapshot_graphml.py index c1eb9b0688..9124038401 100644 --- a/packages/graphrag/graphrag/index/operations/snapshot_graphml.py +++ b/packages/graphrag/graphrag/index/operations/snapshot_graphml.py @@ -4,14 +4,13 @@ """A module containing snapshot_graphml method definition.""" import networkx as nx - -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag_storage import Storage async def snapshot_graphml( input: str | nx.Graph, name: str, - storage: PipelineStorage, + storage: Storage, ) -> None: """Take a entire snapshot of a graph to standard graphml format.""" graphml = input if isinstance(input, str) else "\n".join(nx.generate_graphml(input)) diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index a0b2011eab..5e87249550 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -12,6 +12,7 @@ from typing import Any import pandas as pd +from graphrag_storage import Storage from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig @@ -19,7 +20,6 @@ from graphrag.index.typing.context import 
PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.api import create_cache_from_config, create_storage_from_config from graphrag.utils.storage import load_table_from_storage, write_table_to_storage @@ -156,8 +156,8 @@ async def _dump_json(context: PipelineRunContext) -> None: async def _copy_previous_output( - storage: PipelineStorage, - copy_storage: PipelineStorage, + storage: Storage, + copy_storage: Storage, ): for file in storage.find(re.compile(r"\.parquet$")): base_name = file[0].replace(".parquet", "") diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py index 52b1f0bd31..372023c879 100644 --- a/packages/graphrag/graphrag/index/run/utils.py +++ b/packages/graphrag/graphrag/index/run/utils.py @@ -3,6 +3,9 @@ """Utility functions for the GraphRAG run module.""" +from graphrag_storage import Storage +from graphrag_storage.memory_storage import MemoryStorage + from graphrag.cache.memory_pipeline_cache import InMemoryCache from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks @@ -12,15 +15,13 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats -from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.api import create_storage_from_config def create_run_context( - input_storage: PipelineStorage | None = None, - output_storage: PipelineStorage | None = None, - previous_storage: PipelineStorage | None = None, + input_storage: Storage | None = None, + output_storage: Storage | None = None, + previous_storage: Storage | None = None, cache: 
PipelineCache | None = None, callbacks: WorkflowCallbacks | None = None, stats: PipelineRunStats | None = None, @@ -28,9 +29,9 @@ def create_run_context( ) -> PipelineRunContext: """Create the run context for the pipeline.""" return PipelineRunContext( - input_storage=input_storage or MemoryPipelineStorage(), - output_storage=output_storage or MemoryPipelineStorage(), - previous_storage=previous_storage or MemoryPipelineStorage(), + input_storage=input_storage or MemoryStorage(), + output_storage=output_storage or MemoryStorage(), + previous_storage=previous_storage or MemoryStorage(), cache=cache or InMemoryCache(), callbacks=callbacks or NoopWorkflowCallbacks(), stats=stats or PipelineRunStats(), @@ -50,7 +51,7 @@ def create_callback_chain( def get_update_storages( config: GraphRagConfig, timestamp: str -) -> tuple[PipelineStorage, PipelineStorage, PipelineStorage]: +) -> tuple[Storage, Storage, Storage]: """Get storage objects for the update index run.""" output_storage = create_storage_from_config(config.output) update_storage = create_storage_from_config(config.update_index_output) diff --git a/packages/graphrag/graphrag/index/typing/context.py b/packages/graphrag/graphrag/index/typing/context.py index ef2e1f7ea5..465ec7214c 100644 --- a/packages/graphrag/graphrag/index/typing/context.py +++ b/packages/graphrag/graphrag/index/typing/context.py @@ -10,7 +10,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag_storage import Storage @dataclass @@ -18,11 +18,11 @@ class PipelineRunContext: """Provides the context for the current pipeline run.""" stats: PipelineRunStats - input_storage: PipelineStorage + input_storage: Storage "Storage for input documents." 
- output_storage: PipelineStorage + output_storage: Storage "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider." - previous_storage: PipelineStorage + previous_storage: Storage "Storage for previous pipeline run when running in update mode." cache: PipelineCache "Cache instance for reading previous LLM responses." diff --git a/packages/graphrag/graphrag/index/update/incremental_index.py b/packages/graphrag/graphrag/index/update/incremental_index.py index ac56e30df4..81f917e187 100644 --- a/packages/graphrag/graphrag/index/update/incremental_index.py +++ b/packages/graphrag/graphrag/index/update/incremental_index.py @@ -7,8 +7,8 @@ import numpy as np import pandas as pd +from graphrag_storage import Storage -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import ( load_table_from_storage, write_table_to_storage, @@ -31,16 +31,14 @@ class InputDelta: deleted_inputs: pd.DataFrame -async def get_delta_docs( - input_dataset: pd.DataFrame, storage: PipelineStorage -) -> InputDelta: +async def get_delta_docs(input_dataset: pd.DataFrame, storage: Storage) -> InputDelta: """Get the delta between the input dataset and the final documents. Parameters ---------- input_dataset : pd.DataFrame The input dataset. - storage : PipelineStorage + storage : Storage The Pipeline storage. 
Returns @@ -65,9 +63,9 @@ async def get_delta_docs( async def concat_dataframes( name: str, - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, ) -> pd.DataFrame: """Concatenate dataframes.""" old_df = await load_table_from_storage(name, previous_storage) diff --git a/packages/graphrag/graphrag/index/workflows/load_update_documents.py b/packages/graphrag/graphrag/index/workflows/load_update_documents.py index 7755091017..e68471eae7 100644 --- a/packages/graphrag/graphrag/index/workflows/load_update_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_update_documents.py @@ -6,6 +6,7 @@ import logging import pandas as pd +from graphrag_storage import Storage from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.input.factory import InputReaderFactory @@ -13,7 +14,6 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.incremental_index import get_delta_docs -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ async def run_workflow( async def load_update_documents( input_reader: InputReader, - previous_storage: PipelineStorage, + previous_storage: Storage, ) -> pd.DataFrame: """Load and parse update-only input documents into a standard format.""" input_documents = await input_reader.read_files() diff --git a/packages/graphrag/graphrag/index/workflows/update_communities.py b/packages/graphrag/graphrag/index/workflows/update_communities.py index b7e3e6a343..da4fdef147 100644 --- a/packages/graphrag/graphrag/index/workflows/update_communities.py +++ b/packages/graphrag/graphrag/index/workflows/update_communities.py @@ -5,12 +5,13 @@ import logging +from 
graphrag_storage import Storage + from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_communities -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -37,9 +38,9 @@ async def run_workflow( async def _update_communities( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, ) -> dict: """Update the communities output.""" old_communities = await load_table_from_storage("communities", previous_storage) diff --git a/packages/graphrag/graphrag/index/workflows/update_community_reports.py b/packages/graphrag/graphrag/index/workflows/update_community_reports.py index 42576aca27..790f9fc296 100644 --- a/packages/graphrag/graphrag/index/workflows/update_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/update_community_reports.py @@ -6,13 +6,13 @@ import logging import pandas as pd +from graphrag_storage import Storage from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_community_reports -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -43,9 +43,9 @@ async def run_workflow( async def _update_community_reports( - previous_storage: 
PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, community_id_mapping: dict, ) -> pd.DataFrame: """Update the community reports output.""" diff --git a/packages/graphrag/graphrag/index/workflows/update_covariates.py b/packages/graphrag/graphrag/index/workflows/update_covariates.py index f0bf29a6ae..09f8b4053d 100644 --- a/packages/graphrag/graphrag/index/workflows/update_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/update_covariates.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd +from graphrag_storage import Storage from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import ( load_table_from_storage, storage_has_table, @@ -43,9 +43,9 @@ async def run_workflow( async def _update_covariates( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, ) -> None: """Update the covariates output.""" old_covariates = await load_table_from_storage("covariates", previous_storage) diff --git a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py index 1245303559..2ddd171457 100644 --- a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py +++ b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py @@ -6,6 +6,7 @@ import logging import pandas as pd +from graphrag_storage import Storage from graphrag.cache.pipeline_cache import PipelineCache from 
graphrag.callbacks.workflow_callbacks import WorkflowCallbacks @@ -17,7 +18,6 @@ from graphrag.index.update.relationships import _update_and_merge_relationships from graphrag.index.workflows.extract_graph import get_summarized_entities_relationships from graphrag.language_model.manager import ModelManager -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -55,9 +55,9 @@ async def run_workflow( async def _update_entities_and_relationships( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, config: GraphRagConfig, cache: PipelineCache, callbacks: WorkflowCallbacks, diff --git a/packages/graphrag/graphrag/index/workflows/update_text_units.py b/packages/graphrag/graphrag/index/workflows/update_text_units.py index 392533f16b..c97f89ce7a 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_units.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd +from graphrag_storage import Storage from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -40,9 +40,9 @@ async def run_workflow( async def _update_text_units( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, entity_id_mapping: dict, ) -> pd.DataFrame: 
"""Update the text units output.""" diff --git a/packages/graphrag/graphrag/storage/__init__.py b/packages/graphrag/graphrag/storage/__init__.py index b21f077cb1..94146bcd02 100644 --- a/packages/graphrag/graphrag/storage/__init__.py +++ b/packages/graphrag/graphrag/storage/__init__.py @@ -2,3 +2,7 @@ # Licensed under the MIT License """The storage package root.""" + +from graphrag_storage import create_storage, register_storage + +__all__ = ["create_storage", "register_storage"] diff --git a/packages/graphrag/graphrag/storage/factory.py b/packages/graphrag/graphrag/storage/factory.py deleted file mode 100644 index 738a46420b..0000000000 --- a/packages/graphrag/graphrag/storage/factory.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Factory functions for creating storage.""" - -from __future__ import annotations - -from graphrag_common.factory import Factory - -from graphrag.config.enums import StorageType -from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage -from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage -from graphrag.storage.file_pipeline_storage import FilePipelineStorage -from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.storage.pipeline_storage import PipelineStorage - - -class StorageFactory(Factory[PipelineStorage]): - """A factory class for storage implementations. - - Includes a method for users to register a custom storage implementation. - - Configuration arguments are passed to each storage implementation as kwargs - for individual enforcement of required/optional arguments. 
- """ - - -# --- register built-in storage implementations --- -storage_factory = StorageFactory() -storage_factory.register(StorageType.blob.value, BlobPipelineStorage) -storage_factory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage) -storage_factory.register(StorageType.file.value, FilePipelineStorage) -storage_factory.register(StorageType.memory.value, MemoryPipelineStorage) diff --git a/packages/graphrag/graphrag/utils/api.py b/packages/graphrag/graphrag/utils/api.py index f264c1a9ed..82d86bb544 100644 --- a/packages/graphrag/graphrag/utils/api.py +++ b/packages/graphrag/graphrag/utils/api.py @@ -6,14 +6,13 @@ from pathlib import Path from typing import Any +from graphrag_storage import Storage, StorageConfig, create_storage + from graphrag.cache.factory import CacheFactory from graphrag.cache.pipeline_cache import PipelineCache from graphrag.config.embeddings import create_index_name from graphrag.config.models.cache_config import CacheConfig -from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.storage.factory import StorageFactory -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.vector_stores.base import ( BaseVectorStore, ) @@ -101,12 +100,10 @@ def load_search_prompt(prompt_config: str | None) -> str | None: return None -def create_storage_from_config(output: StorageConfig) -> PipelineStorage: +def create_storage_from_config(output: StorageConfig) -> Storage: """Create a storage object from the config.""" - storage_config = output.model_dump() - return StorageFactory().create( - storage_config["type"], - storage_config, + return create_storage( + output, ) diff --git a/packages/graphrag/graphrag/utils/storage.py b/packages/graphrag/graphrag/utils/storage.py index 8534330a15..852d066091 100644 --- a/packages/graphrag/graphrag/utils/storage.py +++ b/packages/graphrag/graphrag/utils/storage.py @@ -7,13 +7,12 
@@ from io import BytesIO import pandas as pd - -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag_storage import Storage logger = logging.getLogger(__name__) -async def load_table_from_storage(name: str, storage: PipelineStorage) -> pd.DataFrame: +async def load_table_from_storage(name: str, storage: Storage) -> pd.DataFrame: """Load a parquet from the storage instance.""" filename = f"{name}.parquet" if not await storage.has(filename): @@ -28,17 +27,17 @@ async def load_table_from_storage(name: str, storage: PipelineStorage) -> pd.Dat async def write_table_to_storage( - table: pd.DataFrame, name: str, storage: PipelineStorage + table: pd.DataFrame, name: str, storage: Storage ) -> None: """Write a table to storage.""" await storage.set(f"{name}.parquet", table.to_parquet()) -async def delete_table_from_storage(name: str, storage: PipelineStorage) -> None: +async def delete_table_from_storage(name: str, storage: Storage) -> None: """Delete a table to storage.""" await storage.delete(f"{name}.parquet") -async def storage_has_table(name: str, storage: PipelineStorage) -> bool: +async def storage_has_table(name: str, storage: Storage) -> bool: """Check if a table exists in storage.""" return await storage.has(f"{name}.parquet") diff --git a/packages/graphrag/pyproject.toml b/packages/graphrag/pyproject.toml index a7b97f1f0f..7b7eec259d 100644 --- a/packages/graphrag/pyproject.toml +++ b/packages/graphrag/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "devtools>=0.12.2", "environs>=11.0.0", "graphrag-common==2.7.0", + "graphrag-storage==2.7.0", "graspologic-native>=1.2.5", "json-repair>=0.30.3", "lancedb>=0.17.0", diff --git a/pyproject.toml b/pyproject.toml index 3cb1ae1d67..1979df5cb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ members = ["packages/*"] [tool.uv.sources] graphrag-common = { workspace = true } +graphrag-storage = { workspace = true } # Keep poethepoet for task management to minimize changes 
[tool.poe.tasks] @@ -69,6 +70,7 @@ _semversioner_changelog = "semversioner changelog > CHANGELOG.md" # Add more update toml tasks as packages are added _semversioner_update_graphrag_toml_version = "update-toml update --file packages/graphrag/pyproject.toml --path project.version --value $(uv run semversioner current-version)" _semversioner_update_graphrag_common_toml_version = "update-toml update --file packages/graphrag-common/pyproject.toml --path project.version --value $(uv run semversioner current-version)" +_semversioner_update_graphrag_storage_toml_version = "update-toml update --file packages/graphrag-storage/pyproject.toml --path project.version --value $(uv run semversioner current-version)" _semversioner_update_workspace_dependency_versions = "python -m scripts.update_workspace_dependency_versions" semversioner_add = "semversioner add-change" coverage_report = 'coverage report --omit "**/tests/**" --show-missing' @@ -103,6 +105,7 @@ sequence = [ # Add more update toml tasks as packages are added '_semversioner_update_graphrag_toml_version', '_semversioner_update_graphrag_common_toml_version', + '_semversioner_update_graphrag_storage_toml_version', '_semversioner_update_workspace_dependency_versions', '_sync', ] @@ -220,6 +223,7 @@ convention = "numpy" include = [ "packages/graphrag/graphrag", "packages/graphrag-common/graphrag_common", + "packages/graphrag-storage/graphrag_storage", "tests" ] exclude = ["**/node_modules", "**/__pycache__"] diff --git a/tests/integration/storage/test_blob_pipeline_storage.py b/tests/integration/storage/test_blob_storage.py similarity index 94% rename from tests/integration/storage/test_blob_pipeline_storage.py rename to tests/integration/storage/test_blob_storage.py index 818b588bd6..44216700da 100644 --- a/tests/integration/storage/test_blob_pipeline_storage.py +++ b/tests/integration/storage/test_blob_storage.py @@ -5,14 +5,14 @@ import re from datetime import datetime -from graphrag.storage.blob_pipeline_storage import 
BlobPipelineStorage +from graphrag_storage.azure_blob_storage import AzureBlobStorage # cspell:disable-next-line well-known-key WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;" async def test_find(): - storage = BlobPipelineStorage( + storage = AzureBlobStorage( connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, container_name="testfind", ) @@ -42,7 +42,7 @@ async def test_find(): async def test_dotprefix(): - storage = BlobPipelineStorage( + storage = AzureBlobStorage( connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, container_name="testfind", path_prefix=".", @@ -56,7 +56,7 @@ async def test_dotprefix(): async def test_get_creation_date(): - storage = BlobPipelineStorage( + storage = AzureBlobStorage( connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, container_name="testfind", path_prefix=".", @@ -74,7 +74,7 @@ async def test_get_creation_date(): async def test_child(): - parent = BlobPipelineStorage( + parent = AzureBlobStorage( connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, container_name="testchild", ) diff --git a/tests/integration/storage/test_cosmosdb_storage.py b/tests/integration/storage/test_cosmosdb_storage.py index 3d6128872f..9f85d93e0f 100644 --- a/tests/integration/storage/test_cosmosdb_storage.py +++ b/tests/integration/storage/test_cosmosdb_storage.py @@ -8,7 +8,7 @@ from datetime import datetime import pytest -from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage +from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage # cspell:disable-next-line well-known-key WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" @@ -21,7 +21,7 @@ async def test_find(): - storage = CosmosDBPipelineStorage( + storage = 
AzureCosmosStorage( connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, base_dir="testfind", container_name="testfindcontainer", @@ -64,20 +64,20 @@ async def test_find(): async def test_child(): - storage = CosmosDBPipelineStorage( + storage = AzureCosmosStorage( connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, base_dir="testchild", container_name="testchildcontainer", ) try: child_storage = storage.child("child") - assert type(child_storage) is CosmosDBPipelineStorage + assert type(child_storage) is AzureCosmosStorage finally: await storage.clear() async def test_clear(): - storage = CosmosDBPipelineStorage( + storage = AzureCosmosStorage( connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, base_dir="testclear", container_name="testclearcontainer", @@ -107,7 +107,7 @@ async def test_clear(): async def test_get_creation_date(): - storage = CosmosDBPipelineStorage( + storage = AzureCosmosStorage( connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, base_dir="testclear", container_name="testclearcontainer", diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 87a2960dbc..4d65137a93 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -8,13 +8,11 @@ import sys import pytest -from graphrag.config.enums import StorageType -from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage -from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage -from graphrag.storage.factory import StorageFactory -from graphrag.storage.file_pipeline_storage import FilePipelineStorage -from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag_storage import Storage, StorageConfig, create_storage, register_storage +from graphrag_storage.azure_blob_storage import AzureBlobStorage +from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage 
+from graphrag_storage.file_storage import FileStorage +from graphrag_storage.memory_storage import MemoryStorage # cspell:disable-next-line well-known-key WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;" @@ -24,14 +22,14 @@ @pytest.mark.skip(reason="Blob storage emulator is not available in this environment") def test_create_blob_storage(): - kwargs = { - "type": "blob", - "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY, - "base_dir": "testbasedir", - "container_name": "testcontainer", - } - storage = StorageFactory().create(StorageType.blob.value, kwargs) - assert isinstance(storage, BlobPipelineStorage) + config = StorageConfig( + type=AzureBlobStorage.__name__, + connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + base_dir="testbasedir", + container_name="testcontainer", + ) + storage = create_storage(config) + assert isinstance(storage, AzureBlobStorage) @pytest.mark.skipif( @@ -39,63 +37,57 @@ def test_create_blob_storage(): reason="cosmosdb emulator is only available on windows runners at this time", ) def test_create_cosmosdb_storage(): - kwargs = { - "type": "cosmosdb", - "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING, - "base_dir": "testdatabase", - "container_name": "testcontainer", - } - storage = StorageFactory().create(StorageType.cosmosdb.value, kwargs) - assert isinstance(storage, CosmosDBPipelineStorage) + config = StorageConfig( + type=AzureCosmosStorage.__name__, + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + base_dir="testdatabase", + container_name="testcontainer", + ) + storage = create_storage(config) + assert isinstance(storage, AzureCosmosStorage) def test_create_file(): - kwargs = {"type": "file", "base_dir": "/tmp/teststorage"} - storage = StorageFactory().create(StorageType.file.value, kwargs) - assert isinstance(storage, 
FilePipelineStorage) + config = StorageConfig( + type=FileStorage.__name__, + base_dir="/tmp/teststorage", + ) + storage = create_storage(config) + assert isinstance(storage, FileStorage) def test_create_memory_storage(): - kwargs = {} # MemoryPipelineStorage doesn't accept any constructor parameters - storage = StorageFactory().create(StorageType.memory.value, kwargs) - assert isinstance(storage, MemoryPipelineStorage) + config = StorageConfig( + type=MemoryStorage.__name__, + ) + storage = create_storage(config) + assert isinstance(storage, MemoryStorage) def test_register_and_create_custom_storage(): """Test registering and creating a custom storage type.""" from unittest.mock import MagicMock - # Create a mock that satisfies the PipelineStorage interface - custom_storage_class = MagicMock(spec=PipelineStorage) + # Create a mock that satisfies the Storage interface + custom_storage_class = MagicMock(spec=Storage) # Make the mock return a mock instance when instantiated instance = MagicMock() - # We can set attributes on the mock instance, even if they don't exist on PipelineStorage + # We can set attributes on the mock instance, even if they don't exist on Storage instance.initialized = True custom_storage_class.return_value = instance - StorageFactory().register("custom", lambda **kwargs: custom_storage_class(**kwargs)) - storage = StorageFactory().create("custom", {}) + register_storage("custom", lambda **kwargs: custom_storage_class(**kwargs)) + storage = create_storage(StorageConfig(type="custom")) assert custom_storage_class.called assert storage is instance # Access the attribute we set on our mock assert storage.initialized is True # type: ignore # Attribute only exists on our mock - # Check if it's in the list of registered storage types - assert "custom" in StorageFactory() - - -def test_get_storage_types(): - # Check that built-in types are registered - assert StorageType.file.value in StorageFactory() - assert StorageType.memory.value in 
StorageFactory() - assert StorageType.blob.value in StorageFactory() - assert StorageType.cosmosdb.value in StorageFactory() - def test_create_unknown_storage(): with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."): - StorageFactory().create("unknown") + create_storage(StorageConfig(type="unknown")) def test_register_class_directly_works(): @@ -104,9 +96,7 @@ def test_register_class_directly_works(): from collections.abc import Iterator from typing import Any - from graphrag.storage.pipeline_storage import PipelineStorage - - class CustomStorage(PipelineStorage): + class CustomStorage(Storage): def __init__(self, **kwargs): pass @@ -133,7 +123,7 @@ async def has(self, key: str) -> bool: async def clear(self) -> None: pass - def child(self, name: str | None) -> "PipelineStorage": + def child(self, name: str | None) -> "Storage": return self def keys(self) -> list[str]: @@ -143,11 +133,8 @@ async def get_creation_date(self, key: str) -> str: return "2024-01-01 00:00:00 +0000" # StorageFactory allows registering classes directly (no TypeError) - StorageFactory().register("custom_class", CustomStorage) - - # Verify it was registered - assert "custom_class" in StorageFactory() + register_storage("custom_class", CustomStorage) # Test creating an instance - storage = StorageFactory().create("custom_class") + storage = create_storage(StorageConfig(type="custom_class")) assert isinstance(storage, CustomStorage) diff --git a/tests/integration/storage/test_file_pipeline_storage.py b/tests/integration/storage/test_file_storage.py similarity index 86% rename from tests/integration/storage/test_file_pipeline_storage.py rename to tests/integration/storage/test_file_storage.py index 95e329b6bf..b6edc77b03 100644 --- a/tests/integration/storage/test_file_pipeline_storage.py +++ b/tests/integration/storage/test_file_storage.py @@ -7,17 +7,15 @@ from datetime import datetime from pathlib import Path -from graphrag.storage.file_pipeline_storage import ( - 
FilePipelineStorage, +from graphrag_storage.file_storage import ( + FileStorage, ) __dirname__ = os.path.dirname(__file__) async def test_find(): - storage = FilePipelineStorage( - base_dir="tests/fixtures/text/input", - ) + storage = FileStorage(base_dir="tests/fixtures/text/input") items = list(storage.find(file_pattern=re.compile(r".*\.txt$"))) assert items == [str(Path("dulce.txt"))] output = await storage.get("dulce.txt") @@ -32,7 +30,7 @@ async def test_find(): async def test_get_creation_date(): - storage = FilePipelineStorage( + storage = FileStorage( base_dir="tests/fixtures/text/input", ) @@ -45,7 +43,7 @@ async def test_get_creation_date(): async def test_child(): - storage = FilePipelineStorage() + storage = FileStorage() storage = storage.child("tests/fixtures/text/input") items = list(storage.find(re.compile(r".*\.txt$"))) assert items == [str(Path("dulce.txt"))] diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index 9821bed551..53205c7c09 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -17,7 +17,7 @@ from graphrag.query.context_builder.community_context import ( NO_COMMUNITY_RECORDS_WARNING, ) -from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage +from graphrag_storage.azure_blob_storage import AzureBlobStorage logger = logging.getLogger(__name__) @@ -94,7 +94,7 @@ async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], Non input_base_dir = azure.get("input_base_dir") root = Path(input_path) - input_storage = BlobPipelineStorage( + input_storage = AzureBlobStorage( connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING, container_name=input_container, ) diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 001518f62a..0f2235a58b 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -25,11 +25,11 @@ from graphrag.config.models.prune_graph_config import PruneGraphConfig from graphrag.config.models.reporting_config 
import ReportingConfig from graphrag.config.models.snapshots_config import SnapshotsConfig -from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) from graphrag.config.models.vector_store_config import VectorStoreConfig +from graphrag_storage import StorageConfig from pydantic import BaseModel FAKE_API_KEY = "NOT_AN_API_KEY" diff --git a/tests/unit/indexing/cache/test_file_pipeline_cache.py b/tests/unit/indexing/cache/test_file_pipeline_cache.py index c392b4e08e..c672d3718f 100644 --- a/tests/unit/indexing/cache/test_file_pipeline_cache.py +++ b/tests/unit/indexing/cache/test_file_pipeline_cache.py @@ -5,15 +5,15 @@ import unittest from graphrag.cache.json_pipeline_cache import JsonPipelineCache -from graphrag.storage.file_pipeline_storage import ( - FilePipelineStorage, +from graphrag_storage.file_storage import ( + FileStorage, ) TEMP_DIR = "./.tmp" def create_cache(): - storage = FilePipelineStorage(base_dir=os.path.join(os.getcwd(), ".tmp")) + storage = FileStorage(base_dir=os.path.join(os.getcwd(), ".tmp")) return JsonPipelineCache(storage) diff --git a/tests/unit/indexing/input/test_csv_loader.py b/tests/unit/indexing/input/test_csv_loader.py index 8a6b0e351d..72a33ff749 100644 --- a/tests/unit/indexing/input/test_csv_loader.py +++ b/tests/unit/indexing/input/test_csv_loader.py @@ -3,9 +3,9 @@ from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig -from graphrag.config.models.storage_config import StorageConfig from graphrag.index.input.factory import InputReaderFactory from graphrag.utils.api import create_storage_from_config +from graphrag_storage import StorageConfig async def test_csv_loader_one_file(): diff --git a/tests/unit/indexing/input/test_json_loader.py b/tests/unit/indexing/input/test_json_loader.py index 1ce7001aab..bffbe6a630 100644 --- a/tests/unit/indexing/input/test_json_loader.py +++ 
b/tests/unit/indexing/input/test_json_loader.py @@ -3,9 +3,9 @@ from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig -from graphrag.config.models.storage_config import StorageConfig from graphrag.index.input.factory import InputReaderFactory from graphrag.utils.api import create_storage_from_config +from graphrag_storage import StorageConfig async def test_json_loader_one_file_one_object(): diff --git a/tests/unit/indexing/input/test_txt_loader.py b/tests/unit/indexing/input/test_txt_loader.py index 239f622d72..9133fbd91e 100644 --- a/tests/unit/indexing/input/test_txt_loader.py +++ b/tests/unit/indexing/input/test_txt_loader.py @@ -3,9 +3,9 @@ from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig -from graphrag.config.models.storage_config import StorageConfig from graphrag.index.input.factory import InputReaderFactory from graphrag.utils.api import create_storage_from_config +from graphrag_storage import StorageConfig async def test_txt_loader_one_file(): diff --git a/tests/unit/load_config/fixtures/config.yaml b/tests/unit/load_config/fixtures/config.yaml new file mode 100644 index 0000000000..a54919d1eb --- /dev/null +++ b/tests/unit/load_config/fixtures/config.yaml @@ -0,0 +1,10 @@ +name: test_name +value: 100 +nested: + nested_str: nested_value + nested_int: 42 +nested_list: + - nested_str: list_value_1 + nested_int: 7 + - nested_str: list_value_2 + nested_int: 8 \ No newline at end of file diff --git a/uv.lock b/uv.lock index 6ef295df9d..8bca344c84 100644 --- a/uv.lock +++ b/uv.lock @@ -12,6 +12,7 @@ members = [ "graphrag", "graphrag-common", "graphrag-monorepo", + "graphrag-storage", ] [[package]] @@ -1042,6 +1043,7 @@ dependencies = [ { name = "devtools" }, { name = "environs" }, { name = "graphrag-common" }, + { name = "graphrag-storage" }, { name = "graspologic-native" }, { name = "json-repair" }, { name = "lancedb" }, @@ -1073,6 +1075,7 @@ 
requires-dist = [ { name = "devtools", specifier = ">=0.12.2" }, { name = "environs", specifier = ">=11.0.0" }, { name = "graphrag-common", editable = "packages/graphrag-common" }, + { name = "graphrag-storage", editable = "packages/graphrag-storage" }, { name = "graspologic-native", specifier = ">=1.2.5" }, { name = "json-repair", specifier = ">=0.30.3" }, { name = "lancedb", specifier = ">=0.17.0" }, @@ -1160,6 +1163,27 @@ dev = [ { name = "update-toml", specifier = ">=0.2.1" }, ] +[[package]] +name = "graphrag-storage" +version = "2.7.0" +source = { editable = "packages/graphrag-storage" } +dependencies = [ + { name = "azure-cosmos" }, + { name = "azure-identity" }, + { name = "azure-storage-blob" }, + { name = "graphrag-common" }, + { name = "pydantic" }, +] + +[package.metadata] +requires-dist = [ + { name = "azure-cosmos", specifier = ">=4.9.0" }, + { name = "azure-identity", specifier = ">=1.19.0" }, + { name = "azure-storage-blob", specifier = ">=12.24.0" }, + { name = "graphrag-common", editable = "packages/graphrag-common" }, + { name = "pydantic", specifier = ">=2.10.3" }, +] + [[package]] name = "graspologic-native" version = "1.2.5" From 7e41278e80c78bb2ac95d207a37786cc0fab5b67 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 11 Nov 2025 10:02:45 -0800 Subject: [PATCH 02/17] Fix integration tests. 
--- .../graphrag-storage/graphrag_storage/storage_factory.py | 2 +- tests/integration/storage/test_factory.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/storage_factory.py b/packages/graphrag-storage/graphrag_storage/storage_factory.py index bbaec02258..d1ab2f4db0 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_factory.py +++ b/packages/graphrag-storage/graphrag_storage/storage_factory.py @@ -62,7 +62,7 @@ def create_storage(config: StorageConfig) -> Storage: storage_strategy = config.type if storage_strategy not in storage_factory: - msg = f"StorageConfig.type '{storage_strategy}' is not registered in the StorageFactory. Registered types: {', '.join(storage_factory.keys())}" + msg = f"StorageConfig.type '{storage_strategy}' is not registered in the StorageFactory. Registered types: {', '.join(storage_factory.keys())}." raise ValueError(msg) return storage_factory.create(config.type, config.model_dump()) diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 4d65137a93..83677a4764 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -58,6 +58,7 @@ def test_create_file(): def test_create_memory_storage(): config = StorageConfig( + base_dir="", type=MemoryStorage.__name__, ) storage = create_storage(config) @@ -86,7 +87,10 @@ def test_register_and_create_custom_storage(): def test_create_unknown_storage(): - with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."): + with pytest.raises( + ValueError, + match="StorageConfig\\.type 'unknown' is not registered in the StorageFactory\\. 
Registered types: FileStorage, MemoryStorage, AzureBlobStorage, AzureCosmosStorage, custom\\.", + ): create_storage(StorageConfig(type="unknown")) From 4b7e5bc1c751f0e96f959ed4cc501455f2bf032b Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Wed, 12 Nov 2025 08:02:00 -0800 Subject: [PATCH 03/17] Implement copilot feedback. --- packages/graphrag-storage/README.md | 56 ++++++++++++++++++++--- packages/graphrag-storage/pyproject.toml | 2 + tests/integration/storage/test_factory.py | 2 +- uv.lock | 4 ++ 4 files changed, 57 insertions(+), 7 deletions(-) diff --git a/packages/graphrag-storage/README.md b/packages/graphrag-storage/README.md index 37b17a0848..cd31bb0658 100644 --- a/packages/graphrag-storage/README.md +++ b/packages/graphrag-storage/README.md @@ -1,14 +1,58 @@ # GraphRAG Storage +## Basic + ```python +import asyncio from graphrag_storage import StorageConfig, create_storage from graphrag_storage.file_storage import FileStorage -storage = create_storage( - StorageConfig( - type="FileStorage", # or FileStorage.__name__ - base_dir="output" +async def run(): + storage = create_storage( + StorageConfig( + type="FileStorage", # or FileStorage.__name__ + base_dir="output" + ) + ) + + await storage.set("my_key", "value") + print(await storage.get("my_key")) + +if __name__ == "__main__": + asyncio.run(run()) +``` + +## Custom Storage + +```python +import asyncio +from typing import Any +from graphrag_storage import Storage, StorageConfig, create_storage, register_storage + +class MyStorage(Storage): + def __init__(self, some_setting: str, **kwargs: Any): + # Validate settings and initialize + ... + + #Implement rest of interface + ... 
+
+register_storage("MyStorage", MyStorage)
+
+async def run():
+    storage = create_storage(
+        StorageConfig(
+            type="MyStorage",
+            some_setting="My Setting"
+        )
     )
-)
+    # Or use the factory directly to instantiate with a dict instead of using
+    # StorageConfig + create_storage
+    # from graphrag_storage.storage_factory import storage_factory
+    # storage = storage_factory.create(strategy="MyStorage", init_args={"some_setting": "My Setting"})
+
+    await storage.set("my_key", "value")
+    print(await storage.get("my_key"))
 
-```
\ No newline at end of file
+if __name__ == "__main__":
+    asyncio.run(run())
diff --git a/packages/graphrag-storage/pyproject.toml b/packages/graphrag-storage/pyproject.toml
index fb21d4ad2f..464189f950 100644
--- a/packages/graphrag-storage/pyproject.toml
+++ b/packages/graphrag-storage/pyproject.toml
@@ -31,10 +31,12 @@ classifiers = [
     "Programming Language :: Python :: 3.12",
 ]
 dependencies = [
+    "aiofiles>=24.1.0",
     "azure-cosmos>=4.9.0",
     "azure-identity>=1.19.0",
     "azure-storage-blob>=12.24.0",
     "graphrag-common==2.7.0",
+    "pandas>=2.2.3",
     "pydantic>=2.10.3",
 ]
 
diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py
index 83677a4764..bbd5e276f2 100644
--- a/tests/integration/storage/test_factory.py
+++ b/tests/integration/storage/test_factory.py
@@ -89,7 +89,7 @@ def test_register_and_create_custom_storage():
 def test_create_unknown_storage():
     with pytest.raises(
         ValueError,
-        match="StorageConfig\\.type 'unknown' is not registered in the StorageFactory\\. 
Registered types: FileStorage, MemoryStorage, AzureBlobStorage, AzureCosmosStorage, custom\\.", + match="StorageConfig\\.type 'unknown' is not registered in the StorageFactory\\.", ): create_storage(StorageConfig(type="unknown")) diff --git a/uv.lock b/uv.lock index 8bca344c84..8ec2e3830a 100644 --- a/uv.lock +++ b/uv.lock @@ -1168,19 +1168,23 @@ name = "graphrag-storage" version = "2.7.0" source = { editable = "packages/graphrag-storage" } dependencies = [ + { name = "aiofiles" }, { name = "azure-cosmos" }, { name = "azure-identity" }, { name = "azure-storage-blob" }, { name = "graphrag-common" }, + { name = "pandas" }, { name = "pydantic" }, ] [package.metadata] requires-dist = [ + { name = "aiofiles", specifier = ">=24.1.0" }, { name = "azure-cosmos", specifier = ">=4.9.0" }, { name = "azure-identity", specifier = ">=1.19.0" }, { name = "azure-storage-blob", specifier = ">=12.24.0" }, { name = "graphrag-common", editable = "packages/graphrag-common" }, + { name = "pandas", specifier = ">=2.2.3" }, { name = "pydantic", specifier = ">=2.10.3" }, ] From 9b059240d31b09a5827de592b18682f3ffb305f9 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Wed, 12 Nov 2025 08:19:08 -0800 Subject: [PATCH 04/17] Remove create_storage_from_config helper. 
--- .../graphrag_storage/storage.py | 75 ++++++++++++++----- packages/graphrag/graphrag/cli/query.py | 5 +- .../graphrag/index/run/run_pipeline.py | 10 +-- packages/graphrag/graphrag/index/run/utils.py | 7 +- .../graphrag/prompt_tune/loader/input.py | 4 +- packages/graphrag/graphrag/utils/api.py | 9 --- tests/unit/indexing/input/test_csv_loader.py | 11 ++- tests/unit/indexing/input/test_json_loader.py | 13 ++-- tests/unit/indexing/input/test_txt_loader.py | 9 +-- 9 files changed, 86 insertions(+), 57 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/storage.py b/packages/graphrag-storage/graphrag_storage/storage.py index c9fe400331..d8016d4ae7 100644 --- a/packages/graphrag-storage/graphrag_storage/storage.py +++ b/packages/graphrag-storage/graphrag_storage/storage.py @@ -22,7 +22,19 @@ def find( self, file_pattern: re.Pattern[str], ) -> Iterator[str]: - """Find files in the storage using a file pattern.""" + """Find files in the storage using a file pattern. + + Args + ---- + - file_pattern: re.Pattern[str] + The file pattern to use for finding files. + + Returns + ------- + Iterator[str]: + An iterator over the found file keys. + + """ @abstractmethod async def get( @@ -30,42 +42,56 @@ async def get( ) -> Any: """Get the value for the given key. - Args: - - key - The key to get the value for. - - as_bytes - Whether or not to return the value as bytes. + Args + ---- + - key: str + The key to get the value for. + - as_bytes: bool | None, optional (default=None) + Whether or not to return the value as bytes. + - encoding: str | None, optional (default=None) + The encoding to use when decoding the value. Returns ------- - - output - The value for the given key. + Any: + The value for the given key. """ @abstractmethod async def set(self, key: str, value: Any, encoding: str | None = None) -> None: """Set the value for the given key. - Args: - - key - The key to set the value for. - - value - The value to set. 
+ Args + ---- + - key: str + The key to set the value for. + - value: Any + The value to set. """ @abstractmethod async def has(self, key: str) -> bool: """Return True if the given key exists in the storage. - Args: - - key - The key to check for. + Args + ---- + - key: str + The key to check for. Returns ------- - - output - True if the key exists in the storage, False otherwise. + bool: + True if the key exists in the storage, False otherwise. """ @abstractmethod async def delete(self, key: str) -> None: """Delete the given key from the storage. - Args: - - key - The key to delete. + Args + ---- + - key: str + The key to delete. """ @abstractmethod @@ -74,7 +100,19 @@ async def clear(self) -> None: @abstractmethod def child(self, name: str | None) -> "Storage": - """Create a child storage instance.""" + """Create a child storage instance. + + Args + ---- + - name: str | None + The name of the child storage. + + Returns + ------- + Storage + The child storage instance. + + """ @abstractmethod def keys(self) -> list[str]: @@ -84,12 +122,15 @@ def keys(self) -> list[str]: async def get_creation_date(self, key: str) -> str: """Get the creation date for the given key. - Args: - - key - The key to get the creation date for. + Args + ---- + - key: str + The key to get the creation date for. Returns ------- - - output - The creation date for the given key. + str: + The creation date for the given key. 
""" diff --git a/packages/graphrag/graphrag/cli/query.py b/packages/graphrag/graphrag/cli/query.py index 93163db19d..6ce049ddcc 100644 --- a/packages/graphrag/graphrag/cli/query.py +++ b/packages/graphrag/graphrag/cli/query.py @@ -8,11 +8,12 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from graphrag_storage import create_storage + import graphrag.api as api from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks from graphrag.config.load_config import load_config from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.utils.api import create_storage_from_config from graphrag.utils.storage import load_table_from_storage, storage_has_table if TYPE_CHECKING: @@ -376,7 +377,7 @@ def _resolve_output_files( ) -> dict[str, Any]: """Read indexing output files to a dataframe dict.""" dataframe_dict = {} - storage_obj = create_storage_from_config(config.output) + storage_obj = create_storage(config.output) for name in output_list: df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj)) dataframe_dict[name] = df_value diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index 5e87249550..06940822eb 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -12,7 +12,7 @@ from typing import Any import pandas as pd -from graphrag_storage import Storage +from graphrag_storage import Storage, create_storage from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig @@ -20,7 +20,7 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult -from graphrag.utils.api import create_cache_from_config, create_storage_from_config +from graphrag.utils.api import 
create_cache_from_config from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -35,8 +35,8 @@ async def run_pipeline( input_documents: pd.DataFrame | None = None, ) -> AsyncIterable[PipelineRunResult]: """Run all workflows using a simplified pipeline.""" - input_storage = create_storage_from_config(config.input.storage) - output_storage = create_storage_from_config(config.output) + input_storage = create_storage(config.input.storage) + output_storage = create_storage(config.output) cache = create_cache_from_config(config.cache) # load existing state in case any workflows are stateful @@ -49,7 +49,7 @@ async def run_pipeline( if is_update_run: logger.info("Running incremental indexing.") - update_storage = create_storage_from_config(config.update_index_output) + update_storage = create_storage(config.update_index_output) # we use this to store the new subset index, and will merge its content with the previous index update_timestamp = time.strftime("%Y%m%d-%H%M%S") timestamped_storage = update_storage.child(update_timestamp) diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py index 372023c879..03e789746a 100644 --- a/packages/graphrag/graphrag/index/run/utils.py +++ b/packages/graphrag/graphrag/index/run/utils.py @@ -3,7 +3,7 @@ """Utility functions for the GraphRAG run module.""" -from graphrag_storage import Storage +from graphrag_storage import Storage, create_storage from graphrag_storage.memory_storage import MemoryStorage from graphrag.cache.memory_pipeline_cache import InMemoryCache @@ -15,7 +15,6 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats -from graphrag.utils.api import create_storage_from_config def create_run_context( @@ -53,8 +52,8 @@ def get_update_storages( config: GraphRagConfig, timestamp: str ) -> 
tuple[Storage, Storage, Storage]: """Get storage objects for the update index run.""" - output_storage = create_storage_from_config(config.output) - update_storage = create_storage_from_config(config.update_index_output) + output_storage = create_storage(config.output) + update_storage = create_storage(config.update_index_output) timestamped_storage = update_storage.child(timestamp) delta_storage = timestamped_storage.child("delta") previous_storage = timestamped_storage.child("previous") diff --git a/packages/graphrag/graphrag/prompt_tune/loader/input.py b/packages/graphrag/graphrag/prompt_tune/loader/input.py index 5e9fccb440..c810b0ce41 100644 --- a/packages/graphrag/graphrag/prompt_tune/loader/input.py +++ b/packages/graphrag/graphrag/prompt_tune/loader/input.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd +from graphrag_storage import create_storage from graphrag.cache.noop_pipeline_cache import NoopPipelineCache from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks @@ -25,7 +26,6 @@ ) from graphrag.prompt_tune.types import DocSelectionType from graphrag.tokenizer.get_tokenizer import get_tokenizer -from graphrag.utils.api import create_storage_from_config def _sample_chunks_from_embeddings( @@ -63,7 +63,7 @@ async def load_docs_in_chunks( cache=NoopPipelineCache(), ) tokenizer = get_tokenizer(embeddings_llm_settings) - input_storage = create_storage_from_config(config.input.storage) + input_storage = create_storage(config.input.storage) input_reader = InputReaderFactory().create( config.input.file_type, {"storage": input_storage, "config": config.input}, diff --git a/packages/graphrag/graphrag/utils/api.py b/packages/graphrag/graphrag/utils/api.py index 82d86bb544..2d83d692ff 100644 --- a/packages/graphrag/graphrag/utils/api.py +++ b/packages/graphrag/graphrag/utils/api.py @@ -6,8 +6,6 @@ from pathlib import Path from typing import Any -from graphrag_storage import Storage, StorageConfig, create_storage - from 
graphrag.cache.factory import CacheFactory from graphrag.cache.pipeline_cache import PipelineCache from graphrag.config.embeddings import create_index_name @@ -100,13 +98,6 @@ def load_search_prompt(prompt_config: str | None) -> str | None: return None -def create_storage_from_config(output: StorageConfig) -> Storage: - """Create a storage object from the config.""" - return create_storage( - output, - ) - - def create_cache_from_config(cache: CacheConfig) -> PipelineCache: """Create a cache object from the config.""" cache_config = cache.model_dump() diff --git a/tests/unit/indexing/input/test_csv_loader.py b/tests/unit/indexing/input/test_csv_loader.py index 72a33ff749..b0dc645e1b 100644 --- a/tests/unit/indexing/input/test_csv_loader.py +++ b/tests/unit/indexing/input/test_csv_loader.py @@ -4,8 +4,7 @@ from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig from graphrag.index.input.factory import InputReaderFactory -from graphrag.utils.api import create_storage_from_config -from graphrag_storage import StorageConfig +from graphrag_storage import StorageConfig, create_storage async def test_csv_loader_one_file(): @@ -16,7 +15,7 @@ async def test_csv_loader_one_file(): file_type=InputFileType.csv, file_pattern=".*\\.csv$", ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) @@ -35,7 +34,7 @@ async def test_csv_loader_one_file_with_title(): file_pattern=".*\\.csv$", title_column="title", ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) @@ -55,7 +54,7 @@ async def test_csv_loader_one_file_with_metadata(): title_column="title", metadata=["title"], ) - storage = create_storage_from_config(config.storage) + 
storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) @@ -74,7 +73,7 @@ async def test_csv_loader_multiple_files(): file_type=InputFileType.csv, file_pattern=".*\\.csv$", ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) diff --git a/tests/unit/indexing/input/test_json_loader.py b/tests/unit/indexing/input/test_json_loader.py index bffbe6a630..3959096b1c 100644 --- a/tests/unit/indexing/input/test_json_loader.py +++ b/tests/unit/indexing/input/test_json_loader.py @@ -4,8 +4,7 @@ from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig from graphrag.index.input.factory import InputReaderFactory -from graphrag.utils.api import create_storage_from_config -from graphrag_storage import StorageConfig +from graphrag_storage import StorageConfig, create_storage async def test_json_loader_one_file_one_object(): @@ -16,7 +15,7 @@ async def test_json_loader_one_file_one_object(): file_type=InputFileType.json, file_pattern=".*\\.json$", ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) @@ -34,7 +33,7 @@ async def test_json_loader_one_file_multiple_objects(): file_type=InputFileType.json, file_pattern=".*\\.json$", ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) @@ -54,7 +53,7 @@ async def test_json_loader_one_file_with_title(): file_pattern=".*\\.json$", title_column="title", ) - storage = create_storage_from_config(config.storage) + storage = 
create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) @@ -74,7 +73,7 @@ async def test_json_loader_one_file_with_metadata(): title_column="title", metadata=["title"], ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) @@ -92,7 +91,7 @@ async def test_json_loader_multiple_files(): file_type=InputFileType.json, file_pattern=".*\\.json$", ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) diff --git a/tests/unit/indexing/input/test_txt_loader.py b/tests/unit/indexing/input/test_txt_loader.py index 9133fbd91e..57ded82507 100644 --- a/tests/unit/indexing/input/test_txt_loader.py +++ b/tests/unit/indexing/input/test_txt_loader.py @@ -4,8 +4,7 @@ from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig from graphrag.index.input.factory import InputReaderFactory -from graphrag.utils.api import create_storage_from_config -from graphrag_storage import StorageConfig +from graphrag_storage import StorageConfig, create_storage async def test_txt_loader_one_file(): @@ -16,7 +15,7 @@ async def test_txt_loader_one_file(): file_type=InputFileType.text, file_pattern=".*\\.txt$", ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) @@ -35,7 +34,7 @@ async def test_txt_loader_one_file_with_metadata(): file_pattern=".*\\.txt$", metadata=["title"], ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() 
.create(config.file_type, {"storage": storage, "config": config}) @@ -54,7 +53,7 @@ async def test_txt_loader_multiple_files(): file_type=InputFileType.text, file_pattern=".*\\.txt$", ) - storage = create_storage_from_config(config.storage) + storage = create_storage(config.storage) documents = ( await InputReaderFactory() .create(config.file_type, {"storage": storage, "config": config}) From cd87faaa59d8b0310518fb4301bd440b16c5db60 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 2 Dec 2025 09:55:27 -0800 Subject: [PATCH 05/17] update factory method for dynamic imports and use enums for factory registration --- docs/config/yaml.md | 6 +- packages/graphrag-storage/README.md | 5 +- .../graphrag_storage/__init__.py | 7 ++- .../graphrag_storage/azure_cosmos_storage.py | 3 +- .../graphrag_storage/storage_config.py | 4 +- .../graphrag_storage/storage_factory.py | 57 ++++++++++++------- .../graphrag_storage/storage_type.py | 16 ++++++ packages/graphrag/graphrag/config/defaults.py | 4 +- .../graphrag/graphrag/config/init_content.py | 4 +- .../config/models/graph_rag_config.py | 9 ++- tests/integration/storage/test_factory.py | 16 ++++-- 11 files changed, 86 insertions(+), 45 deletions(-) create mode 100644 packages/graphrag-storage/graphrag_storage/storage_type.py diff --git a/docs/config/yaml.md b/docs/config/yaml.md index 9a45e2f313..19c1cfd05e 100644 --- a/docs/config/yaml.md +++ b/docs/config/yaml.md @@ -81,7 +81,7 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th #### Fields - `storage` **StorageConfig** - - `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` + - `type` **file|azure_blob|azure_cosmos** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. 
- `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -115,7 +115,7 @@ This section controls the storage mechanism used by the pipeline used for export #### Fields -- `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` +- `type` **file|azure_blob|azure_cosmos** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -128,7 +128,7 @@ The section defines a secondary storage location for running incremental indexin #### Fields -- `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` +- `type` **file|azure_blob|azure_cosmos** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. 
diff --git a/packages/graphrag-storage/README.md b/packages/graphrag-storage/README.md index cd31bb0658..37243ebc04 100644 --- a/packages/graphrag-storage/README.md +++ b/packages/graphrag-storage/README.md @@ -4,13 +4,12 @@ ```python import asyncio -from graphrag_storage import StorageConfig, create_storage -from graphrag_storage.file_storage import FileStorage +from graphrag_storage import StorageConfig, create_storage, StorageType async def run(): storage = create_storage( StorageConfig( - type="FileStorage", # or FileStorage.__name__ + type=StorageType.FILE, base_dir="output" ) ) diff --git a/packages/graphrag-storage/graphrag_storage/__init__.py b/packages/graphrag-storage/graphrag_storage/__init__.py index 0684dfb889..2ae67be741 100644 --- a/packages/graphrag-storage/graphrag_storage/__init__.py +++ b/packages/graphrag-storage/graphrag_storage/__init__.py @@ -5,11 +5,16 @@ from graphrag_storage.storage import Storage from graphrag_storage.storage_config import StorageConfig -from graphrag_storage.storage_factory import create_storage, register_storage +from graphrag_storage.storage_factory import ( + create_storage, + register_storage, +) +from graphrag_storage.storage_type import StorageType __all__ = [ "Storage", "StorageConfig", + "StorageType", "create_storage", "register_storage", ] diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index 4e4e034eb7..edad9495ec 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -45,6 +45,7 @@ def __init__( container_name: str | None = None, connection_string: str | None = None, cosmosdb_account_url: str | None = None, + encoding: str = "utf-8", **kwargs: Any, ) -> None: """Create a CosmosDB storage instance.""" @@ -77,7 +78,7 @@ def __init__( url=cosmosdb_account_url, credential=DefaultAzureCredential(), ) - self._encoding
= kwargs.get("encoding", "utf-8") + self._encoding = encoding self._database_name = database_name self._connection_string = connection_string self._cosmosdb_account_url = cosmosdb_account_url diff --git a/packages/graphrag-storage/graphrag_storage/storage_config.py b/packages/graphrag-storage/graphrag_storage/storage_config.py index 0a8cf76893..83216f6a2e 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_config.py +++ b/packages/graphrag-storage/graphrag_storage/storage_config.py @@ -5,6 +5,8 @@ from pydantic import BaseModel, ConfigDict, Field +from graphrag_storage.storage_type import StorageType + class StorageConfig(BaseModel): """The default configuration section for storage.""" @@ -14,7 +16,7 @@ class StorageConfig(BaseModel): type: str = Field( description="The storage type to use.", - default="FileStorage", + default=StorageType.FILE, ) base_dir: str | None = Field( diff --git a/packages/graphrag-storage/graphrag_storage/storage_factory.py b/packages/graphrag-storage/graphrag_storage/storage_factory.py index d1ab2f4db0..caea8db15d 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_factory.py +++ b/packages/graphrag-storage/graphrag_storage/storage_factory.py @@ -8,42 +8,31 @@ from graphrag_common.factory import Factory -from graphrag_storage.azure_blob_storage import AzureBlobStorage -from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage -from graphrag_storage.file_storage import FileStorage -from graphrag_storage.memory_storage import MemoryStorage from graphrag_storage.storage import Storage from graphrag_storage.storage_config import StorageConfig +from graphrag_storage.storage_type import StorageType class _StorageFactory(Factory[Storage]): - """A factory class for storage implementations. - - Includes a method for users to register a custom storage implementation. - - Configuration arguments are passed to each storage implementation as kwargs - for individual enforcement of required/optional arguments. 
- """ + """A factory class for storage implementations.""" storage_factory = _StorageFactory() -storage_factory.register(FileStorage.__name__, FileStorage) -storage_factory.register(MemoryStorage.__name__, MemoryStorage) -storage_factory.register(AzureBlobStorage.__name__, AzureBlobStorage) -storage_factory.register(AzureCosmosStorage.__name__, AzureCosmosStorage) -def register_storage(storage: str, storage_initializer: Callable[..., Storage]) -> None: +def register_storage( + storage_type: str, storage_initializer: Callable[..., Storage] +) -> None: """Register a custom storage implementation. Args ---- - - storage: str + - storage_type: str The storage id to register. - storage_initializer: Callable[..., Storage] The storage initializer to register. """ - storage_factory.register(storage, storage_initializer) + storage_factory.register(storage_type, storage_initializer) def create_storage(config: StorageConfig) -> Storage: @@ -59,10 +48,34 @@ def create_storage(config: StorageConfig) -> Storage: Storage The created storage implementation. """ - storage_strategy = config.type + config_model = config.model_dump() + storage_strategy = config_model.pop("type") + + # Check storage_strategy is a string + if not isinstance(storage_strategy, str): + msg = f"StorageConfig.type must be a string, got {type(storage_strategy)}" + raise TypeError(msg) if storage_strategy not in storage_factory: - msg = f"StorageConfig.type '{storage_strategy}' is not registered in the StorageFactory. Registered types: {', '.join(storage_factory.keys())}." 
- raise ValueError(msg) + match storage_strategy: + case StorageType.FILE: + from graphrag_storage.file_storage import FileStorage + + register_storage(StorageType.FILE, FileStorage) + case StorageType.MEMORY: + from graphrag_storage.memory_storage import MemoryStorage + + register_storage(StorageType.MEMORY, MemoryStorage) + case StorageType.AZURE_BLOB: + from graphrag_storage.azure_blob_storage import AzureBlobStorage + + register_storage(StorageType.AZURE_BLOB, AzureBlobStorage) + case StorageType.AZURE_COSMOS: + from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage + + register_storage(StorageType.AZURE_COSMOS, AzureCosmosStorage) + case _: + msg = f"StorageConfig.type '{storage_strategy}' is not registered in the StorageFactory. Registered types: {', '.join(storage_factory.keys())}." + raise ValueError(msg) - return storage_factory.create(config.type, config.model_dump()) + return storage_factory.create(storage_strategy, config_model) diff --git a/packages/graphrag-storage/graphrag_storage/storage_type.py b/packages/graphrag-storage/graphrag_storage/storage_type.py new file mode 100644 index 0000000000..89a59b00d8 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/storage_type.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + + +"""Builtin storage implementation types.""" + +from enum import StrEnum + + +class StorageType(StrEnum): + """Enum for storage types.""" + + FILE = "file" + MEMORY = "memory" + AZURE_BLOB = "azure_blob" + AZURE_COSMOS = "azure_cosmos" diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index 7e2284ec2a..ea45fb7381 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import ClassVar -from graphrag_storage.file_storage import FileStorage +from graphrag_storage import StorageType from graphrag.config.embeddings import default_embeddings from graphrag.config.enums import ( @@ -231,7 +231,7 @@ class GlobalSearchDefaults: class StorageDefaults: """Default values for storage.""" - type: str = FileStorage.__name__ + type: str = StorageType.FILE base_dir: str | None = None connection_string: None = None container_name: None = None diff --git a/packages/graphrag/graphrag/config/init_content.py b/packages/graphrag/graphrag/config/init_content.py index 04787f9a43..59025a3259 100644 --- a/packages/graphrag/graphrag/config/init_content.py +++ b/packages/graphrag/graphrag/config/init_content.py @@ -50,7 +50,7 @@ input: storage: - type: {graphrag_config_defaults.input.storage.type} # or AzureBlobStorage, AzureCosmosStorage + type: {graphrag_config_defaults.input.storage.type} # or azure_blob, azure_cosmos base_dir: "{graphrag_config_defaults.input.storage.base_dir}" file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json] @@ -63,7 +63,7 @@ ## connection_string and container_name must be provided output: - type: {graphrag_config_defaults.output.type} # or AzureBlobStorage, AzureCosmosStorage + type: {graphrag_config_defaults.output.type} # or azure_blob, azure_cosmos base_dir: "{graphrag_config_defaults.output.base_dir}" cache: diff --git 
a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index db8bdfae43..d19f8f733d 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -6,8 +6,7 @@ from pathlib import Path from devtools import pformat -from graphrag_storage import StorageConfig -from graphrag_storage.file_storage import FileStorage +from graphrag_storage import StorageConfig, StorageType from pydantic import BaseModel, Field, model_validator import graphrag.config.defaults as defs @@ -117,7 +116,7 @@ def _validate_input_pattern(self) -> None: def _validate_input_base_dir(self) -> None: """Validate the input base directory.""" - if self.input.storage.type == FileStorage.__name__: + if self.input.storage.type == StorageType.FILE: if not self.input.storage.base_dir: msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration." raise ValueError(msg) @@ -141,7 +140,7 @@ def _validate_input_base_dir(self) -> None: def _validate_output_base_dir(self) -> None: """Validate the output base directory.""" - if self.output.type == FileStorage.__name__: + if self.output.type == StorageType.FILE: if not self.output.base_dir: msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration." raise ValueError(msg) @@ -157,7 +156,7 @@ def _validate_output_base_dir(self) -> None: def _validate_update_index_output_base_dir(self) -> None: """Validate the update index output base directory.""" - if self.update_index_output.type == FileStorage.__name__: + if self.update_index_output.type == StorageType.FILE: if not self.update_index_output.base_dir: msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration." 
raise ValueError(msg) diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index bbd5e276f2..2b06d6e711 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -8,7 +8,13 @@ import sys import pytest -from graphrag_storage import Storage, StorageConfig, create_storage, register_storage +from graphrag_storage import ( + Storage, + StorageConfig, + StorageType, + create_storage, + register_storage, +) from graphrag_storage.azure_blob_storage import AzureBlobStorage from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage from graphrag_storage.file_storage import FileStorage @@ -23,7 +29,7 @@ @pytest.mark.skip(reason="Blob storage emulator is not available in this environment") def test_create_blob_storage(): config = StorageConfig( - type=AzureBlobStorage.__name__, + type=StorageType.AZURE_BLOB, connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, base_dir="testbasedir", container_name="testcontainer", @@ -38,7 +44,7 @@ def test_create_blob_storage(): ) def test_create_cosmosdb_storage(): config = StorageConfig( - type=AzureCosmosStorage.__name__, + type=StorageType.AZURE_COSMOS, connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, base_dir="testdatabase", container_name="testcontainer", @@ -49,7 +55,7 @@ def test_create_cosmosdb_storage(): def test_create_file(): config = StorageConfig( - type=FileStorage.__name__, + type=StorageType.FILE, base_dir="/tmp/teststorage", ) storage = create_storage(config) @@ -59,7 +65,7 @@ def test_create_file(): def test_create_memory_storage(): config = StorageConfig( base_dir="", - type=MemoryStorage.__name__, + type=StorageType.MEMORY, ) storage = create_storage(config) assert isinstance(storage, MemoryStorage) From 37c031656fb243bef667e58b58b9252b52f41c8f Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 2 Dec 2025 10:59:51 -0800 Subject: [PATCH 06/17] Add encoding to storage config --- 
packages/graphrag-common/graphrag_common/factory/factory.py | 5 ++++- packages/graphrag-storage/graphrag_storage/file_storage.py | 2 +- packages/graphrag-storage/graphrag_storage/memory_storage.py | 4 ++++ packages/graphrag-storage/graphrag_storage/storage_config.py | 5 +++++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/packages/graphrag-common/graphrag_common/factory/factory.py b/packages/graphrag-common/graphrag_common/factory/factory.py index d9aa236cc4..8f57606aa3 100644 --- a/packages/graphrag-common/graphrag_common/factory/factory.py +++ b/packages/graphrag-common/graphrag_common/factory/factory.py @@ -84,11 +84,14 @@ def create(self, strategy: str, init_args: dict[str, Any] | None = None) -> T: msg = f"Strategy '{strategy}' is not registered. Registered strategies are: {', '.join(list(self._service_initializers.keys()))}" raise ValueError(msg) + # Delete entries with value None + init_args = {k: v for k, v in (init_args or {}).items() if v is not None} + service_descriptor = self._service_initializers[strategy] if service_descriptor.scope == "singleton": if strategy not in self._initialized_services: self._initialized_services[strategy] = service_descriptor.initializer( - **(init_args or {}) + **init_args ) return self._initialized_services[strategy] diff --git a/packages/graphrag-storage/graphrag_storage/file_storage.py b/packages/graphrag-storage/graphrag_storage/file_storage.py index 61cb922ec6..c6cf4684be 100644 --- a/packages/graphrag-storage/graphrag_storage/file_storage.py +++ b/packages/graphrag-storage/graphrag_storage/file_storage.py @@ -31,7 +31,7 @@ class FileStorage(Storage): _encoding: str def __init__( - self, base_dir: str | None = "", encoding: str = "utf-8", **kwargs: Any + self, base_dir: str | None = None, encoding: str = "utf-8", **kwargs: Any ) -> None: """Create a file based storage.""" if base_dir is None: diff --git a/packages/graphrag-storage/graphrag_storage/memory_storage.py 
b/packages/graphrag-storage/graphrag_storage/memory_storage.py index 7908d98a35..d79d8a1b7e 100644 --- a/packages/graphrag-storage/graphrag_storage/memory_storage.py +++ b/packages/graphrag-storage/graphrag_storage/memory_storage.py @@ -18,6 +18,10 @@ class MemoryStorage(FileStorage): def __init__(self, **kwargs: Any) -> None: """Init method definition.""" + kwargs = { + "base_dir": "", + **kwargs, + } super().__init__(**kwargs) self._storage = {} diff --git a/packages/graphrag-storage/graphrag_storage/storage_config.py b/packages/graphrag-storage/graphrag_storage/storage_config.py index 83216f6a2e..2f53614df3 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_config.py +++ b/packages/graphrag-storage/graphrag_storage/storage_config.py @@ -19,6 +19,11 @@ class StorageConfig(BaseModel): default=StorageType.FILE, ) + encoding: str | None = Field( + description="The encoding to use for file storage.", + default=None, + ) + base_dir: str | None = Field( description="The base directory for the output.", default=None, From 4ecf218d7f22f6b6a8247f76be38285676a62465 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 2 Dec 2025 11:21:49 -0800 Subject: [PATCH 07/17] cleanup blob container name validation --- .../graphrag_storage/azure_blob_storage.py | 26 +++---------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py index 9028259bf8..66eb14f7a6 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py @@ -277,27 +277,7 @@ def _validate_blob_container_name(container_name: str) -> None: ------- bool: True if valid, False otherwise. """ - # Check the length of the name - if len(container_name) < 3 or len(container_name) > 63: - msg = f"Container name must be between 3 and 63 characters in length. 
Name provided was {len(container_name)} characters long." - raise ValueError(msg) - - # Check if the name starts with a letter or number - if not container_name[0].isalnum(): - msg = f"Container name must start with a letter or number. Starting character was {container_name[0]}." - raise ValueError(msg) - - # Check for valid characters (letters, numbers, hyphen) and lowercase letters - if not re.match(r"^[a-z0-9-]+$", container_name): - msg = f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}." - raise ValueError(msg) - - # Check for consecutive hyphens - if "--" in container_name: - msg = f"Container name cannot contain consecutive hyphens. Name provided was {container_name}." - raise ValueError(msg) - - # Check for hyphens at the end of the name - if container_name[-1] == "-": - msg = f"Container name cannot end with a hyphen. Name provided was {container_name}." + # Match alphanumeric or single hyphen not at the start or end, repeated 3-63 times. + if not re.match(r"^(?:[0-9a-z]|(? Date: Tue, 2 Dec 2025 11:52:08 -0800 Subject: [PATCH 08/17] Remove using kwargs to swallow unknown factory config parameters. Unknown config parameters for a given implementation will throw an exception. 
--- packages/graphrag-storage/README.md | 2 +- .../graphrag_storage/azure_blob_storage.py | 22 +++++++------------ .../graphrag_storage/azure_cosmos_storage.py | 20 ++++++----------- .../graphrag_storage/file_storage.py | 9 +------- .../graphrag_storage/memory_storage.py | 1 + .../graphrag_storage/storage.py | 4 ---- packages/graphrag/graphrag/cache/factory.py | 3 +++ .../integration/storage/test_file_storage.py | 2 +- 8 files changed, 22 insertions(+), 41 deletions(-) diff --git a/packages/graphrag-storage/README.md b/packages/graphrag-storage/README.md index 37243ebc04..ffda7441fb 100644 --- a/packages/graphrag-storage/README.md +++ b/packages/graphrag-storage/README.md @@ -29,7 +29,7 @@ from typing import Any from graphrag_storage import Storage, StorageConfig, create_storage, register_storage class MyStorage(Storage): - def __init__(self, some_setting: str, **kwargs: Any): + def __init__(self, some_setting: str): # Validate settings and initialize ... diff --git a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py index 66eb14f7a6..daed579f7e 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py @@ -33,29 +33,18 @@ class AzureBlobStorage(Storage): def __init__( self, - base_dir: str | None = None, - connection_string: str | None = None, + container_name: str, storage_account_blob_url: str | None = None, - container_name: str | None = None, + connection_string: str | None = None, + base_dir: str | None = None, encoding: str = "utf-8", - **kwargs: Any, ) -> None: """Create a new BlobStorage instance.""" - if connection_string is None and storage_account_blob_url is None: - msg = "AzureBlobStorage requires either a connection_string or storage_account_blob_url to be specified." 
- logger.error(msg) - raise ValueError(msg) - if connection_string is not None and storage_account_blob_url is not None: msg = "AzureBlobStorage requires only one of connection_string or storage_account_blob_url to be specified, not both." logger.error(msg) raise ValueError(msg) - if container_name is None: - msg = "AzureBlobStorage requires a container_name to be specified." - logger.error(msg) - raise ValueError(msg) - _validate_blob_container_name(container_name) logger.info( @@ -70,6 +59,11 @@ def __init__( account_url=storage_account_blob_url, credential=DefaultAzureCredential(), ) + else: + msg = "AzureBlobStorage requires either a connection_string or storage_account_blob_url to be specified." + logger.error(msg) + raise ValueError(msg) + self._encoding = encoding self._container_name = container_name self._connection_string = connection_string diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index edad9495ec..d619942761 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -41,12 +41,11 @@ class AzureCosmosStorage(Storage): def __init__( self, - base_dir: str | None = None, - container_name: str | None = None, + base_dir: str, + container_name: str, connection_string: str | None = None, cosmosdb_account_url: str | None = None, encoding: str = "utf-8", - **kwargs: Any, ) -> None: """Create a CosmosDB storage instance.""" logger.info("Creating cosmosdb storage") @@ -56,21 +55,11 @@ def __init__( logger.error(msg) raise ValueError(msg) - if connection_string is None and cosmosdb_account_url is None: - msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified." 
- logger.error(msg) - raise ValueError(msg) - if connection_string is not None and cosmosdb_account_url is not None: msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified, not both." logger.error(msg) raise ValueError(msg) - if container_name is None: - msg = "CosmosDB Storage requires a container_name to be specified." - logger.error(msg) - raise ValueError(msg) - if connection_string: self._cosmos_client = CosmosClient.from_connection_string(connection_string) elif cosmosdb_account_url: @@ -78,6 +67,11 @@ def __init__( url=cosmosdb_account_url, credential=DefaultAzureCredential(), ) + else: + msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified." + logger.error(msg) + raise ValueError(msg) + self._encoding = encoding self._database_name = database_name self._connection_string = connection_string diff --git a/packages/graphrag-storage/graphrag_storage/file_storage.py b/packages/graphrag-storage/graphrag_storage/file_storage.py index c6cf4684be..6c9f845734 100644 --- a/packages/graphrag-storage/graphrag_storage/file_storage.py +++ b/packages/graphrag-storage/graphrag_storage/file_storage.py @@ -30,15 +30,8 @@ class FileStorage(Storage): _base_dir: Path _encoding: str - def __init__( - self, base_dir: str | None = None, encoding: str = "utf-8", **kwargs: Any - ) -> None: + def __init__(self, base_dir: str, encoding: str = "utf-8") -> None: """Create a file based storage.""" - if base_dir is None: - msg = "FileStorage requires a base_dir to be specified." 
- logger.error(msg) - raise ValueError(msg) - self._base_dir = Path(base_dir).resolve() self._encoding = encoding logger.info("Creating file storage at [%s]", self._base_dir) diff --git a/packages/graphrag-storage/graphrag_storage/memory_storage.py b/packages/graphrag-storage/graphrag_storage/memory_storage.py index d79d8a1b7e..f92a52a204 100644 --- a/packages/graphrag-storage/graphrag_storage/memory_storage.py +++ b/packages/graphrag-storage/graphrag_storage/memory_storage.py @@ -22,6 +22,7 @@ def __init__(self, **kwargs: Any) -> None: "base_dir": "", **kwargs, } + kwargs.pop("type", None) super().__init__(**kwargs) self._storage = {} diff --git a/packages/graphrag-storage/graphrag_storage/storage.py b/packages/graphrag-storage/graphrag_storage/storage.py index d8016d4ae7..e356af2d64 100644 --- a/packages/graphrag-storage/graphrag_storage/storage.py +++ b/packages/graphrag-storage/graphrag_storage/storage.py @@ -13,10 +13,6 @@ class Storage(ABC): """Provide a storage interface.""" - @abstractmethod - def __init__(self, **kwargs: Any) -> None: - """Create a storage instance.""" - @abstractmethod def find( self, diff --git a/packages/graphrag/graphrag/cache/factory.py b/packages/graphrag/graphrag/cache/factory.py index ccbf1e200f..2660d0fea2 100644 --- a/packages/graphrag/graphrag/cache/factory.py +++ b/packages/graphrag/graphrag/cache/factory.py @@ -30,18 +30,21 @@ class CacheFactory(Factory[PipelineCache]): # --- register built-in cache implementations --- def create_file_cache(**kwargs) -> PipelineCache: """Create a file-based cache implementation.""" + kwargs.pop("type", None) storage = FileStorage(**kwargs) return JsonPipelineCache(storage) def create_blob_cache(**kwargs) -> PipelineCache: """Create a blob storage-based cache implementation.""" + kwargs.pop("type", None) storage = AzureBlobStorage(**kwargs) return JsonPipelineCache(storage) def create_cosmosdb_cache(**kwargs) -> PipelineCache: """Create a CosmosDB-based cache implementation.""" + 
kwargs.pop("type", None) storage = AzureCosmosStorage(**kwargs) return JsonPipelineCache(storage) diff --git a/tests/integration/storage/test_file_storage.py b/tests/integration/storage/test_file_storage.py index b6edc77b03..0852cd99fb 100644 --- a/tests/integration/storage/test_file_storage.py +++ b/tests/integration/storage/test_file_storage.py @@ -43,7 +43,7 @@ async def test_get_creation_date(): async def test_child(): - storage = FileStorage() + storage = FileStorage(base_dir="") storage = storage.child("tests/fixtures/text/input") items = list(storage.find(re.compile(r".*\.txt$"))) assert items == [str(Path("dulce.txt"))] From ecc4c772418a75fab7e8b4a92d76c62a96577c0a Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 2 Dec 2025 12:09:12 -0800 Subject: [PATCH 09/17] fix integration tests. --- packages/graphrag/graphrag/config/defaults.py | 1 + tests/integration/storage/test_blob_storage.py | 15 --------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index ea45fb7381..62597389f7 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -232,6 +232,7 @@ class StorageDefaults: """Default values for storage.""" type: str = StorageType.FILE + encoding: str | None = None base_dir: str | None = None connection_string: None = None container_name: None = None diff --git a/tests/integration/storage/test_blob_storage.py b/tests/integration/storage/test_blob_storage.py index 44216700da..ec996e91a2 100644 --- a/tests/integration/storage/test_blob_storage.py +++ b/tests/integration/storage/test_blob_storage.py @@ -41,25 +41,10 @@ async def test_find(): storage._delete_container() # noqa: SLF001 -async def test_dotprefix(): - storage = AzureBlobStorage( - connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, - container_name="testfind", - path_prefix=".", - ) - try: - await storage.set("input/christmas.txt", 
"Merry Christmas!", encoding="utf-8") - items = list(storage.find(file_pattern=re.compile(r".*\.txt$"))) - assert items == ["input/christmas.txt"] - finally: - storage._delete_container() # noqa: SLF001 - - async def test_get_creation_date(): storage = AzureBlobStorage( connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, container_name="testfind", - path_prefix=".", ) try: await storage.set("input/christmas.txt", "Merry Christmas!", encoding="utf-8") From 287b6b1cef6a80d02df949630b2b95cd0b732048 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 2 Dec 2025 14:01:52 -0800 Subject: [PATCH 10/17] cleanup storage config for handling azure services. --- .../graphrag_storage/azure_blob_storage.py | 41 +++++++++++-------- .../graphrag_storage/azure_cosmos_storage.py | 35 +++++++++------- .../graphrag_storage/storage_config.py | 24 ++++++----- packages/graphrag/graphrag/config/defaults.py | 8 ++-- .../integration/storage/test_blob_storage.py | 12 +++--- .../storage/test_cosmosdb_storage.py | 24 +++++------ tests/integration/storage/test_factory.py | 10 ++--- tests/smoke/test_fixtures.py | 4 +- tests/unit/config/utils.py | 31 ++++++++------ 9 files changed, 105 insertions(+), 84 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py index daed579f7e..bc334ed84d 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py @@ -33,30 +33,35 @@ class AzureBlobStorage(Storage): def __init__( self, - container_name: str, - storage_account_blob_url: str | None = None, - connection_string: str | None = None, + azure_container_name: str, + azure_storage_account_blob_url: str | None = None, + azure_connection_string: str | None = None, base_dir: str | None = None, encoding: str = "utf-8", ) -> None: """Create a new BlobStorage instance.""" - if connection_string is not None and 
storage_account_blob_url is not None: + if ( + azure_connection_string is not None + and azure_storage_account_blob_url is not None + ): msg = "AzureBlobStorage requires only one of connection_string or storage_account_blob_url to be specified, not both." logger.error(msg) raise ValueError(msg) - _validate_blob_container_name(container_name) + _validate_blob_container_name(azure_container_name) logger.info( - "Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir + "Creating blob storage at [%s] and base_dir [%s]", + azure_container_name, + base_dir, ) - if connection_string: + if azure_connection_string: self._blob_service_client = BlobServiceClient.from_connection_string( - connection_string + azure_connection_string ) - elif storage_account_blob_url: + elif azure_storage_account_blob_url: self._blob_service_client = BlobServiceClient( - account_url=storage_account_blob_url, + account_url=azure_storage_account_blob_url, credential=DefaultAzureCredential(), ) else: @@ -65,13 +70,13 @@ def __init__( raise ValueError(msg) self._encoding = encoding - self._container_name = container_name - self._connection_string = connection_string + self._container_name = azure_container_name + self._connection_string = azure_connection_string self._base_dir = base_dir - self._storage_account_blob_url = storage_account_blob_url + self._storage_account_blob_url = azure_storage_account_blob_url self._storage_account_name = ( - storage_account_blob_url.split("//")[1].split(".")[0] - if storage_account_blob_url + azure_storage_account_blob_url.split("//")[1].split(".")[0] + if azure_storage_account_blob_url else None ) self._create_container() @@ -220,11 +225,11 @@ def child(self, name: str | None) -> "Storage": return self path = str(Path(self._base_dir) / name) if self._base_dir else name return AzureBlobStorage( - connection_string=self._connection_string, - container_name=self._container_name, + azure_connection_string=self._connection_string, + 
azure_container_name=self._container_name, encoding=self._encoding, base_dir=path, - storage_account_blob_url=self._storage_account_blob_url, + azure_storage_account_blob_url=self._storage_account_blob_url, ) def keys(self) -> list[str]: diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index d619942761..04d2dd4de8 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -41,30 +41,35 @@ class AzureCosmosStorage(Storage): def __init__( self, - base_dir: str, - container_name: str, - connection_string: str | None = None, - cosmosdb_account_url: str | None = None, + azure_cosmosdb_database_name: str, + azure_container_name: str, + azure_connection_string: str | None = None, + azure_cosmosdb_account_url: str | None = None, encoding: str = "utf-8", ) -> None: """Create a CosmosDB storage instance.""" logger.info("Creating cosmosdb storage") - database_name = base_dir + database_name = azure_cosmosdb_database_name if database_name is None: msg = "CosmosDB Storage requires a base_dir to be specified. This is used as the database name." logger.error(msg) raise ValueError(msg) - if connection_string is not None and cosmosdb_account_url is not None: + if ( + azure_connection_string is not None + and azure_cosmosdb_account_url is not None + ): msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified, not both." 
logger.error(msg) raise ValueError(msg) - if connection_string: - self._cosmos_client = CosmosClient.from_connection_string(connection_string) - elif cosmosdb_account_url: + if azure_connection_string: + self._cosmos_client = CosmosClient.from_connection_string( + azure_connection_string + ) + elif azure_cosmosdb_account_url: self._cosmos_client = CosmosClient( - url=cosmosdb_account_url, + url=azure_cosmosdb_account_url, credential=DefaultAzureCredential(), ) else: @@ -74,12 +79,12 @@ def __init__( self._encoding = encoding self._database_name = database_name - self._connection_string = connection_string - self._cosmosdb_account_url = cosmosdb_account_url - self._container_name = container_name + self._connection_string = azure_connection_string + self._cosmosdb_account_url = azure_cosmosdb_account_url + self._container_name = azure_container_name self._cosmosdb_account_name = ( - cosmosdb_account_url.split("//")[1].split(".")[0] - if cosmosdb_account_url + azure_cosmosdb_account_url.split("//")[1].split(".")[0] + if azure_cosmosdb_account_url else None ) self._no_id_prefixes = [] diff --git a/packages/graphrag-storage/graphrag_storage/storage_config.py b/packages/graphrag-storage/graphrag_storage/storage_config.py index 2f53614df3..b0e2253f20 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_config.py +++ b/packages/graphrag-storage/graphrag_storage/storage_config.py @@ -15,7 +15,7 @@ class StorageConfig(BaseModel): """Allow extra fields to support custom storage implementations.""" type: str = Field( - description="The storage type to use.", + description="The storage type to use. 
Builtin types include 'file', 'azure_blob', and 'azure_cosmos'.", default=StorageType.FILE, ) @@ -25,24 +25,28 @@ class StorageConfig(BaseModel): ) base_dir: str | None = Field( - description="The base directory for the output.", + description="The base directory for the output when using file or azure_blob storage.", default=None, ) - connection_string: str | None = Field( - description="The storage connection string to use.", + azure_connection_string: str | None = Field( + description="The connection string for Azure Blob Storage or Azure CosmosDB.", default=None, ) - container_name: str | None = Field( - description="The storage container name to use.", + azure_container_name: str | None = Field( + description="The Azure Blob Storage container name or CosmosDB container name to use.", default=None, ) - storage_account_blob_url: str | None = Field( - description="The storage account blob url to use.", + azure_storage_account_blob_url: str | None = Field( + description="The Azure Blob Storage account blob url to use.", default=None, ) - cosmosdb_account_url: str | None = Field( - description="The cosmosdb account url to use.", + azure_cosmosdb_database_name: str | None = Field( + description="The Azure CosmosDB database name to use.", + default=None, + ) + azure_cosmosdb_account_url: str | None = Field( + description="The Azure CosmosDB account url to use.", default=None, ) diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index 62597389f7..e70f321a04 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -234,10 +234,10 @@ class StorageDefaults: type: str = StorageType.FILE encoding: str | None = None base_dir: str | None = None - connection_string: None = None - container_name: None = None - storage_account_blob_url: None = None - cosmosdb_account_url: None = None + azure_connection_string: None = None + azure_container_name: None = None + 
azure_storage_account_blob_url: None = None + azure_cosmosdb_account_url: None = None @dataclass diff --git a/tests/integration/storage/test_blob_storage.py b/tests/integration/storage/test_blob_storage.py index ec996e91a2..9b1654aa70 100644 --- a/tests/integration/storage/test_blob_storage.py +++ b/tests/integration/storage/test_blob_storage.py @@ -13,8 +13,8 @@ async def test_find(): storage = AzureBlobStorage( - connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, - container_name="testfind", + azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + azure_container_name="testfind", ) try: try: @@ -43,8 +43,8 @@ async def test_find(): async def test_get_creation_date(): storage = AzureBlobStorage( - connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, - container_name="testfind", + azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + azure_container_name="testfind", ) try: await storage.set("input/christmas.txt", "Merry Christmas!", encoding="utf-8") @@ -60,8 +60,8 @@ async def test_get_creation_date(): async def test_child(): parent = AzureBlobStorage( - connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, - container_name="testchild", + azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + azure_container_name="testchild", ) try: try: diff --git a/tests/integration/storage/test_cosmosdb_storage.py b/tests/integration/storage/test_cosmosdb_storage.py index 9f85d93e0f..5db0d0898f 100644 --- a/tests/integration/storage/test_cosmosdb_storage.py +++ b/tests/integration/storage/test_cosmosdb_storage.py @@ -22,9 +22,9 @@ async def test_find(): storage = AzureCosmosStorage( - connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - base_dir="testfind", - container_name="testfindcontainer", + azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + azure_cosmosdb_database_name="testfind", + azure_container_name="testfindcontainer", ) try: try: @@ -65,9 +65,9 @@ async def test_find(): async def test_child(): storage = AzureCosmosStorage( - 
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - base_dir="testchild", - container_name="testchildcontainer", + azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + azure_cosmosdb_database_name="testchild", + azure_container_name="testchildcontainer", ) try: child_storage = storage.child("child") @@ -78,9 +78,9 @@ async def test_child(): async def test_clear(): storage = AzureCosmosStorage( - connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - base_dir="testclear", - container_name="testclearcontainer", + azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + azure_cosmosdb_database_name="testclear", + azure_container_name="testclearcontainer", ) try: json_exists = { @@ -108,9 +108,9 @@ async def test_clear(): async def test_get_creation_date(): storage = AzureCosmosStorage( - connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - base_dir="testclear", - container_name="testclearcontainer", + azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + azure_cosmosdb_database_name="testclear", + azure_container_name="testclearcontainer", ) try: json_content = { diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 2b06d6e711..32d8f9b26d 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -30,9 +30,9 @@ def test_create_blob_storage(): config = StorageConfig( type=StorageType.AZURE_BLOB, - connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, base_dir="testbasedir", - container_name="testcontainer", + azure_container_name="testcontainer", ) storage = create_storage(config) assert isinstance(storage, AzureBlobStorage) @@ -45,9 +45,9 @@ def test_create_blob_storage(): def test_create_cosmosdb_storage(): config = StorageConfig( type=StorageType.AZURE_COSMOS, - connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - base_dir="testdatabase", - container_name="testcontainer", + 
azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + azure_cosmosdb_database_name="testdatabase", + azure_container_name="testcontainer", ) storage = create_storage(config) assert isinstance(storage, AzureCosmosStorage) diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index 53205c7c09..2e3b2b09e3 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -95,8 +95,8 @@ async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], Non root = Path(input_path) input_storage = AzureBlobStorage( - connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING, - container_name=input_container, + azure_connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING, + azure_container_name=input_container, ) # Bounce the container if it exists to clear out old run data input_storage._delete_container() # noqa: SLF001 diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 0f2235a58b..522dcc87c4 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -129,10 +129,12 @@ def assert_reporting_configs( def assert_output_configs(actual: StorageConfig, expected: StorageConfig) -> None: assert expected.type == actual.type assert expected.base_dir == actual.base_dir - assert expected.connection_string == actual.connection_string - assert expected.container_name == actual.container_name - assert expected.storage_account_blob_url == actual.storage_account_blob_url - assert expected.cosmosdb_account_url == actual.cosmosdb_account_url + assert expected.azure_connection_string == actual.azure_connection_string + assert expected.azure_container_name == actual.azure_container_name + assert ( + expected.azure_storage_account_blob_url == actual.azure_storage_account_blob_url + ) + assert expected.azure_cosmosdb_account_url == actual.azure_cosmosdb_account_url def assert_update_output_configs( @@ -140,10 +142,12 @@ def assert_update_output_configs( ) -> None: assert expected.type == actual.type 
assert expected.base_dir == actual.base_dir - assert expected.connection_string == actual.connection_string - assert expected.container_name == actual.container_name - assert expected.storage_account_blob_url == actual.storage_account_blob_url - assert expected.cosmosdb_account_url == actual.cosmosdb_account_url + assert expected.azure_connection_string == actual.azure_connection_string + assert expected.azure_container_name == actual.azure_container_name + assert ( + expected.azure_storage_account_blob_url == actual.azure_storage_account_blob_url + ) + assert expected.azure_cosmosdb_account_url == actual.azure_cosmosdb_account_url def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: @@ -159,12 +163,15 @@ def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None: assert actual.storage.type == expected.storage.type assert actual.file_type == expected.file_type assert actual.storage.base_dir == expected.storage.base_dir - assert actual.storage.connection_string == expected.storage.connection_string assert ( - actual.storage.storage_account_blob_url - == expected.storage.storage_account_blob_url + actual.storage.azure_connection_string + == expected.storage.azure_connection_string + ) + assert ( + actual.storage.azure_storage_account_blob_url + == expected.storage.azure_storage_account_blob_url ) - assert actual.storage.container_name == expected.storage.container_name + assert actual.storage.azure_container_name == expected.storage.azure_container_name assert actual.encoding == expected.encoding assert actual.file_pattern == expected.file_pattern assert actual.text_column == expected.text_column From 3b01e27dce2b7f88f155b12c68467264e78d8fbe Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 2 Dec 2025 14:21:14 -0800 Subject: [PATCH 11/17] fix integration tests. 
--- tests/integration/cache/test_factory.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/integration/cache/test_factory.py b/tests/integration/cache/test_factory.py index 766adc1f8f..8313e36bdc 100644 --- a/tests/integration/cache/test_factory.py +++ b/tests/integration/cache/test_factory.py @@ -41,8 +41,8 @@ def test_create_file_cache(): def test_create_blob_cache(): init_args = { - "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY, - "container_name": "testcontainer", + "azure_connection_string": WELL_KNOWN_BLOB_STORAGE_KEY, + "azure_container_name": "testcontainer", "base_dir": "testcache", } cache = CacheFactory().create(strategy=CacheType.blob.value, init_args=init_args) @@ -55,9 +55,9 @@ def test_create_blob_cache(): ) def test_create_cosmosdb_cache(): init_args = { - "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING, - "base_dir": "testdatabase", - "container_name": "testcontainer", + "azure_connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING, + "azure_cosmosdb_database_name": "testdatabase", + "azure_container_name": "testcontainer", } cache = CacheFactory().create( strategy=CacheType.cosmosdb.value, init_args=init_args From 0529dfeffcd89d14df35b6031e4292ad59ec737e Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Wed, 3 Dec 2025 07:25:04 -0800 Subject: [PATCH 12/17] update storage config. 
--- docs/config/yaml.md | 6 +-- packages/graphrag-storage/README.md | 2 +- .../graphrag_storage/azure_blob_storage.py | 19 ++++----- .../graphrag_storage/azure_cosmos_storage.py | 17 ++++---- .../graphrag_storage/storage_config.py | 12 ++---- .../graphrag_storage/storage_type.py | 8 ++-- packages/graphrag/graphrag/cache/factory.py | 9 +++-- .../graphrag/graphrag/config/init_content.py | 4 +- tests/unit/config/utils.py | 40 ++++--------------- 9 files changed, 42 insertions(+), 75 deletions(-) diff --git a/docs/config/yaml.md b/docs/config/yaml.md index 19c1cfd05e..1574f0ad1d 100644 --- a/docs/config/yaml.md +++ b/docs/config/yaml.md @@ -81,7 +81,7 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th #### Fields - `storage` **StorageConfig** - - `type` **file|azure_blob|azure_cosmos** - The storage type to use. Default=`file` + - `type` **File|AzureBlob|AzureCosmos** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -115,7 +115,7 @@ This section controls the storage mechanism used by the pipeline used for export #### Fields -- `type` **file|azure_blob|azure_cosmos** - The storage type to use. Default=`file` +- `type` **File|AzureBlob|AzureCosmos** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -128,7 +128,7 @@ The section defines a secondary storage location for running incremental indexin #### Fields -- `type` **file|azure_blob|azure_cosmos** - The storage type to use. 
Default=`file` +- `type` **File|AzureBlob|AzureCosmos** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. diff --git a/packages/graphrag-storage/README.md b/packages/graphrag-storage/README.md index ffda7441fb..c8cd117d1b 100644 --- a/packages/graphrag-storage/README.md +++ b/packages/graphrag-storage/README.md @@ -29,7 +29,7 @@ from typing import Any from graphrag_storage import Storage, StorageConfig, create_storage, register_storage class MyStorage(Storage): - def __init__(self, some_setting: str): + def __init__(self, some_setting: str, optional_setting: str = "default setting"): # Validate settings and initialize ... diff --git a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py index bc334ed84d..57ab9961b9 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py @@ -34,16 +34,13 @@ class AzureBlobStorage(Storage): def __init__( self, azure_container_name: str, - azure_storage_account_blob_url: str | None = None, + azure_account_url: str | None = None, azure_connection_string: str | None = None, base_dir: str | None = None, encoding: str = "utf-8", ) -> None: """Create a new BlobStorage instance.""" - if ( - azure_connection_string is not None - and azure_storage_account_blob_url is not None - ): + if azure_connection_string is not None and azure_account_url is not None: msg = "AzureBlobStorage requires only one of connection_string or storage_account_blob_url to be specified, not both." 
logger.error(msg) raise ValueError(msg) @@ -59,9 +56,9 @@ def __init__( self._blob_service_client = BlobServiceClient.from_connection_string( azure_connection_string ) - elif azure_storage_account_blob_url: + elif azure_account_url: self._blob_service_client = BlobServiceClient( - account_url=azure_storage_account_blob_url, + account_url=azure_account_url, credential=DefaultAzureCredential(), ) else: @@ -73,10 +70,10 @@ def __init__( self._container_name = azure_container_name self._connection_string = azure_connection_string self._base_dir = base_dir - self._storage_account_blob_url = azure_storage_account_blob_url + self._storage_account_blob_url = azure_account_url self._storage_account_name = ( - azure_storage_account_blob_url.split("//")[1].split(".")[0] - if azure_storage_account_blob_url + azure_account_url.split("//")[1].split(".")[0] + if azure_account_url else None ) self._create_container() @@ -229,7 +226,7 @@ def child(self, name: str | None) -> "Storage": azure_container_name=self._container_name, encoding=self._encoding, base_dir=path, - azure_storage_account_blob_url=self._storage_account_blob_url, + azure_account_url=self._storage_account_blob_url, ) def keys(self) -> list[str]: diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index 04d2dd4de8..aa8b1ec70f 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -44,7 +44,7 @@ def __init__( azure_cosmosdb_database_name: str, azure_container_name: str, azure_connection_string: str | None = None, - azure_cosmosdb_account_url: str | None = None, + azure_account_url: str | None = None, encoding: str = "utf-8", ) -> None: """Create a CosmosDB storage instance.""" @@ -55,10 +55,7 @@ def __init__( logger.error(msg) raise ValueError(msg) - if ( - azure_connection_string is not None - and 
azure_cosmosdb_account_url is not None - ): + if azure_connection_string is not None and azure_account_url is not None: msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified, not both." logger.error(msg) raise ValueError(msg) @@ -67,9 +64,9 @@ def __init__( self._cosmos_client = CosmosClient.from_connection_string( azure_connection_string ) - elif azure_cosmosdb_account_url: + elif azure_account_url: self._cosmos_client = CosmosClient( - url=azure_cosmosdb_account_url, + url=azure_account_url, credential=DefaultAzureCredential(), ) else: @@ -80,11 +77,11 @@ def __init__( self._encoding = encoding self._database_name = database_name self._connection_string = azure_connection_string - self._cosmosdb_account_url = azure_cosmosdb_account_url + self._cosmosdb_account_url = azure_account_url self._container_name = azure_container_name self._cosmosdb_account_name = ( - azure_cosmosdb_account_url.split("//")[1].split(".")[0] - if azure_cosmosdb_account_url + azure_account_url.split("//")[1].split(".")[0] + if azure_account_url else None ) self._no_id_prefixes = [] diff --git a/packages/graphrag-storage/graphrag_storage/storage_config.py b/packages/graphrag-storage/graphrag_storage/storage_config.py index b0e2253f20..60766ce31d 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_config.py +++ b/packages/graphrag-storage/graphrag_storage/storage_config.py @@ -15,7 +15,7 @@ class StorageConfig(BaseModel): """Allow extra fields to support custom storage implementations.""" type: str = Field( - description="The storage type to use. Builtin types include 'file', 'azure_blob', and 'azure_cosmos'.", + description="The storage type to use. 
Builtin types include 'File', 'AzureBlob', and 'AzureCosmos'.", default=StorageType.FILE, ) @@ -25,7 +25,7 @@ class StorageConfig(BaseModel): ) base_dir: str | None = Field( - description="The base directory for the output when using file or azure_blob storage.", + description="The base directory for the output when using file or AzureBlob storage.", default=None, ) @@ -38,15 +38,11 @@ class StorageConfig(BaseModel): description="The Azure Blob Storage container name or CosmosDB container name to use.", default=None, ) - azure_storage_account_blob_url: str | None = Field( - description="The Azure Blob Storage account blob url to use.", + azure_account_url: str | None = Field( + description="The account url for Azure Blob Storage or Azure CosmosDB.", default=None, ) azure_cosmosdb_database_name: str | None = Field( description="The Azure CosmosDB database name to use.", default=None, ) - azure_cosmosdb_account_url: str | None = Field( - description="The Azure CosmosDB account url to use.", - default=None, - ) diff --git a/packages/graphrag-storage/graphrag_storage/storage_type.py b/packages/graphrag-storage/graphrag_storage/storage_type.py index 89a59b00d8..9577d33b9e 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_type.py +++ b/packages/graphrag-storage/graphrag_storage/storage_type.py @@ -10,7 +10,7 @@ class StorageType(StrEnum): """Enum for storage types.""" - FILE = "file" - MEMORY = "memory" - AZURE_BLOB = "azure_blob" - AZURE_COSMOS = "azure_cosmos" + FILE = "File" + MEMORY = "Memory" + AZURE_BLOB = "AzureBlob" + AZURE_COSMOS = "AzureCosmos" diff --git a/packages/graphrag/graphrag/cache/factory.py b/packages/graphrag/graphrag/cache/factory.py index 2660d0fea2..e8c23ded4f 100644 --- a/packages/graphrag/graphrag/cache/factory.py +++ b/packages/graphrag/graphrag/cache/factory.py @@ -6,9 +6,6 @@ from __future__ import annotations from graphrag_common.factory import Factory -from graphrag_storage.azure_blob_storage import AzureBlobStorage -from 
graphrag_storage.azure_cosmos_storage import AzureCosmosStorage -from graphrag_storage.file_storage import FileStorage from graphrag.cache.json_pipeline_cache import JsonPipelineCache from graphrag.cache.memory_pipeline_cache import InMemoryCache @@ -30,6 +27,8 @@ class CacheFactory(Factory[PipelineCache]): # --- register built-in cache implementations --- def create_file_cache(**kwargs) -> PipelineCache: """Create a file-based cache implementation.""" + from graphrag_storage.file_storage import FileStorage + kwargs.pop("type", None) storage = FileStorage(**kwargs) return JsonPipelineCache(storage) @@ -37,6 +36,8 @@ def create_file_cache(**kwargs) -> PipelineCache: def create_blob_cache(**kwargs) -> PipelineCache: """Create a blob storage-based cache implementation.""" + from graphrag_storage.azure_blob_storage import AzureBlobStorage + kwargs.pop("type", None) storage = AzureBlobStorage(**kwargs) return JsonPipelineCache(storage) @@ -44,6 +45,8 @@ def create_blob_cache(**kwargs) -> PipelineCache: def create_cosmosdb_cache(**kwargs) -> PipelineCache: """Create a CosmosDB-based cache implementation.""" + from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage + kwargs.pop("type", None) storage = AzureCosmosStorage(**kwargs) return JsonPipelineCache(storage) diff --git a/packages/graphrag/graphrag/config/init_content.py b/packages/graphrag/graphrag/config/init_content.py index 59025a3259..831c0fbba2 100644 --- a/packages/graphrag/graphrag/config/init_content.py +++ b/packages/graphrag/graphrag/config/init_content.py @@ -50,7 +50,7 @@ input: storage: - type: {graphrag_config_defaults.input.storage.type} # or azure_blob, azure_cosmos + type: {graphrag_config_defaults.input.storage.type} # or AzureBlob, AzureCosmos base_dir: "{graphrag_config_defaults.input.storage.base_dir}" file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json] @@ -63,7 +63,7 @@ ## connection_string and container_name must be provided output: - type: 
{graphrag_config_defaults.output.type} # or azure_blob, azure_cosmos + type: {graphrag_config_defaults.output.type} # or AzureBlob, AzureCosmos base_dir: "{graphrag_config_defaults.output.base_dir}" cache: diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 522dcc87c4..3c2e5320e2 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -126,28 +126,14 @@ def assert_reporting_configs( assert actual.storage_account_blob_url == expected.storage_account_blob_url -def assert_output_configs(actual: StorageConfig, expected: StorageConfig) -> None: +def assert_storage_config(actual: StorageConfig, expected: StorageConfig) -> None: assert expected.type == actual.type assert expected.base_dir == actual.base_dir assert expected.azure_connection_string == actual.azure_connection_string assert expected.azure_container_name == actual.azure_container_name - assert ( - expected.azure_storage_account_blob_url == actual.azure_storage_account_blob_url - ) - assert expected.azure_cosmosdb_account_url == actual.azure_cosmosdb_account_url - - -def assert_update_output_configs( - actual: StorageConfig, expected: StorageConfig -) -> None: - assert expected.type == actual.type - assert expected.base_dir == actual.base_dir - assert expected.azure_connection_string == actual.azure_connection_string - assert expected.azure_container_name == actual.azure_container_name - assert ( - expected.azure_storage_account_blob_url == actual.azure_storage_account_blob_url - ) - assert expected.azure_cosmosdb_account_url == actual.azure_cosmosdb_account_url + assert expected.azure_account_url == actual.azure_account_url + assert expected.encoding == actual.encoding + assert expected.azure_cosmosdb_database_name == actual.azure_cosmosdb_database_name def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: @@ -160,18 +146,8 @@ def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: def assert_input_configs(actual: 
InputConfig, expected: InputConfig) -> None: - assert actual.storage.type == expected.storage.type + assert_storage_config(actual.storage, expected.storage) assert actual.file_type == expected.file_type - assert actual.storage.base_dir == expected.storage.base_dir - assert ( - actual.storage.azure_connection_string - == expected.storage.azure_connection_string - ) - assert ( - actual.storage.azure_storage_account_blob_url - == expected.storage.azure_storage_account_blob_url - ) - assert actual.storage.azure_container_name == expected.storage.azure_container_name assert actual.encoding == expected.encoding assert actual.file_pattern == expected.file_pattern assert actual.text_column == expected.text_column @@ -365,11 +341,9 @@ def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) -> assert_vector_store_configs(actual.vector_store, expected.vector_store) assert_reporting_configs(actual.reporting, expected.reporting) - assert_output_configs(actual.output, expected.output) + assert_storage_config(actual.output, expected.output) - assert_update_output_configs( - actual.update_index_output, expected.update_index_output - ) + assert_storage_config(actual.update_index_output, expected.update_index_output) assert_cache_configs(actual.cache, expected.cache) assert_input_configs(actual.input, expected.input) From 3a472331ca753d7a8d4f64d94f952093ee4ec0a0 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Wed, 3 Dec 2025 10:16:00 -0800 Subject: [PATCH 13/17] updates --- docs/config/yaml.md | 6 ++-- packages/graphrag-storage/README.md | 29 +++++++++++++++++-- .../graphrag_storage/azure_blob_storage.py | 1 + .../graphrag_storage/azure_cosmos_storage.py | 1 + .../graphrag_storage/file_storage.py | 2 +- .../graphrag_storage/storage.py | 4 +++ .../graphrag_storage/storage_config.py | 2 +- .../graphrag_storage/storage_factory.py | 27 +++++++---------- .../graphrag_storage/storage_type.py | 8 ++--- packages/graphrag/graphrag/cache/factory.py | 3 -- 
packages/graphrag/graphrag/config/defaults.py | 2 +- .../graphrag/graphrag/config/init_content.py | 4 +-- .../config/models/graph_rag_config.py | 6 ++-- tests/integration/storage/test_factory.py | 8 ++--- 14 files changed, 63 insertions(+), 40 deletions(-) diff --git a/docs/config/yaml.md b/docs/config/yaml.md index 1574f0ad1d..c82ff5bb12 100644 --- a/docs/config/yaml.md +++ b/docs/config/yaml.md @@ -81,7 +81,7 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th #### Fields - `storage` **StorageConfig** - - `type` **File|AzureBlob|AzureCosmos** - The storage type to use. Default=`file` + - `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -115,7 +115,7 @@ This section controls the storage mechanism used by the pipeline used for export #### Fields -- `type` **File|AzureBlob|AzureCosmos** - The storage type to use. Default=`file` +- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -128,7 +128,7 @@ The section defines a secondary storage location for running incremental indexin #### Fields -- `type` **File|AzureBlob|AzureCosmos** - The storage type to use. Default=`file` +- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. 
- `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. diff --git a/packages/graphrag-storage/README.md b/packages/graphrag-storage/README.md index c8cd117d1b..cb04667b64 100644 --- a/packages/graphrag-storage/README.md +++ b/packages/graphrag-storage/README.md @@ -9,7 +9,7 @@ from graphrag_storage import StorageConfig, create_storage, StorageType async def run(): storage = create_storage( StorageConfig( - type=StorageType.FILE + type=StorageType.File base_dir="output" ) ) @@ -29,7 +29,7 @@ from typing import Any from graphrag_storage import Storage, StorageConfig, create_storage, register_storage class MyStorage(Storage): - def __init__(self, some_setting: str, optional_setting: str = "default setting"): + def __init__(self, some_setting: str, optional_setting: str = "default setting", **kwargs: Any): # Validate settings and initialize ... @@ -55,3 +55,28 @@ async def run(): if __name__ == "__main__": asyncio.run(run()) +``` + +### Information + +By default, the `create_storage` comes with the following storage providers registered that correspond to the entries in the `StorageType` enum. + +- `FileStorage` +- `AzureBlobStorage` +- `AzureCosmosStorage` +- `MemoryStorage` + +You can directly import `storage_factory` if you want a clean factory with no preregistered storage providers. + +```python +from graphrag_storage.storage_factory import storage_factory +from graphrag_storage.file_storage import FileStorage + +# Or register a custom implementation, see above for example. +storage_factory.register("my_storage_key", FileStorage) + +storage = storage_factory.create(strategy="my_storage_key", init_args={"base_dir": "...", "other_settings": "...",}) + +... 
+ +``` \ No newline at end of file diff --git a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py index 57ab9961b9..35001e48e5 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py @@ -38,6 +38,7 @@ def __init__( azure_connection_string: str | None = None, base_dir: str | None = None, encoding: str = "utf-8", + **kwargs: Any, ) -> None: """Create a new BlobStorage instance.""" if azure_connection_string is not None and azure_account_url is not None: diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index aa8b1ec70f..78a63eb2cb 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -46,6 +46,7 @@ def __init__( azure_connection_string: str | None = None, azure_account_url: str | None = None, encoding: str = "utf-8", + **kwargs: Any, ) -> None: """Create a CosmosDB storage instance.""" logger.info("Creating cosmosdb storage") diff --git a/packages/graphrag-storage/graphrag_storage/file_storage.py b/packages/graphrag-storage/graphrag_storage/file_storage.py index 6c9f845734..547659abcd 100644 --- a/packages/graphrag-storage/graphrag_storage/file_storage.py +++ b/packages/graphrag-storage/graphrag_storage/file_storage.py @@ -30,7 +30,7 @@ class FileStorage(Storage): _base_dir: Path _encoding: str - def __init__(self, base_dir: str, encoding: str = "utf-8") -> None: + def __init__(self, base_dir: str, encoding: str = "utf-8", **kwargs: Any) -> None: """Create a file based storage.""" self._base_dir = Path(base_dir).resolve() self._encoding = encoding diff --git a/packages/graphrag-storage/graphrag_storage/storage.py b/packages/graphrag-storage/graphrag_storage/storage.py index e356af2d64..d8016d4ae7 
100644 --- a/packages/graphrag-storage/graphrag_storage/storage.py +++ b/packages/graphrag-storage/graphrag_storage/storage.py @@ -13,6 +13,10 @@ class Storage(ABC): """Provide a storage interface.""" + @abstractmethod + def __init__(self, **kwargs: Any) -> None: + """Create a storage instance.""" + @abstractmethod def find( self, diff --git a/packages/graphrag-storage/graphrag_storage/storage_config.py b/packages/graphrag-storage/graphrag_storage/storage_config.py index 60766ce31d..abcd420010 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_config.py +++ b/packages/graphrag-storage/graphrag_storage/storage_config.py @@ -16,7 +16,7 @@ class StorageConfig(BaseModel): type: str = Field( description="The storage type to use. Builtin types include 'File', 'AzureBlob', and 'AzureCosmos'.", - default=StorageType.FILE, + default=StorageType.File, ) encoding: str | None = Field( diff --git a/packages/graphrag-storage/graphrag_storage/storage_factory.py b/packages/graphrag-storage/graphrag_storage/storage_factory.py index caea8db15d..1341d7ecf4 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_factory.py +++ b/packages/graphrag-storage/graphrag_storage/storage_factory.py @@ -13,11 +13,11 @@ from graphrag_storage.storage_type import StorageType -class _StorageFactory(Factory[Storage]): +class StorageFactory(Factory[Storage]): """A factory class for storage implementations.""" -storage_factory = _StorageFactory() +storage_factory = StorageFactory() def register_storage( @@ -49,31 +49,26 @@ def create_storage(config: StorageConfig) -> Storage: The created storage implementation. 
""" config_model = config.model_dump() - storage_strategy = config_model.pop("type") - - # Check storage_strategy is a string - if not isinstance(storage_strategy, str): - msg = f"StorageConfig.type must be a string, got {type(storage_strategy)}" - raise TypeError(msg) + storage_strategy = config.type if storage_strategy not in storage_factory: match storage_strategy: - case StorageType.FILE: + case StorageType.File: from graphrag_storage.file_storage import FileStorage - register_storage(StorageType.FILE, FileStorage) - case StorageType.MEMORY: + register_storage(StorageType.File, FileStorage) + case StorageType.Memory: from graphrag_storage.memory_storage import MemoryStorage - register_storage(StorageType.MEMORY, MemoryStorage) - case StorageType.AZURE_BLOB: + register_storage(StorageType.Memory, MemoryStorage) + case StorageType.AzureBlob: from graphrag_storage.azure_blob_storage import AzureBlobStorage - register_storage(StorageType.AZURE_BLOB, AzureBlobStorage) - case StorageType.AZURE_COSMOS: + register_storage(StorageType.AzureBlob, AzureBlobStorage) + case StorageType.AzureCosmos: from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage - register_storage(StorageType.AZURE_COSMOS, AzureCosmosStorage) + register_storage(StorageType.AzureCosmos, AzureCosmosStorage) case _: msg = f"StorageConfig.type '{storage_strategy}' is not registered in the StorageFactory. Registered types: {', '.join(storage_factory.keys())}." 
raise ValueError(msg) diff --git a/packages/graphrag-storage/graphrag_storage/storage_type.py b/packages/graphrag-storage/graphrag_storage/storage_type.py index 9577d33b9e..dd2b1376fd 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_type.py +++ b/packages/graphrag-storage/graphrag_storage/storage_type.py @@ -10,7 +10,7 @@ class StorageType(StrEnum): """Enum for storage types.""" - FILE = "File" - MEMORY = "Memory" - AZURE_BLOB = "AzureBlob" - AZURE_COSMOS = "AzureCosmos" + File = "file" + Memory = "memory" + AzureBlob = "blob" + AzureCosmos = "cosmosdb" diff --git a/packages/graphrag/graphrag/cache/factory.py b/packages/graphrag/graphrag/cache/factory.py index e8c23ded4f..d62d8c420b 100644 --- a/packages/graphrag/graphrag/cache/factory.py +++ b/packages/graphrag/graphrag/cache/factory.py @@ -29,7 +29,6 @@ def create_file_cache(**kwargs) -> PipelineCache: """Create a file-based cache implementation.""" from graphrag_storage.file_storage import FileStorage - kwargs.pop("type", None) storage = FileStorage(**kwargs) return JsonPipelineCache(storage) @@ -38,7 +37,6 @@ def create_blob_cache(**kwargs) -> PipelineCache: """Create a blob storage-based cache implementation.""" from graphrag_storage.azure_blob_storage import AzureBlobStorage - kwargs.pop("type", None) storage = AzureBlobStorage(**kwargs) return JsonPipelineCache(storage) @@ -47,7 +45,6 @@ def create_cosmosdb_cache(**kwargs) -> PipelineCache: """Create a CosmosDB-based cache implementation.""" from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage - kwargs.pop("type", None) storage = AzureCosmosStorage(**kwargs) return JsonPipelineCache(storage) diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index e70f321a04..cefcc676f3 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -231,7 +231,7 @@ class GlobalSearchDefaults: class StorageDefaults: """Default values for 
storage.""" - type: str = StorageType.FILE + type: str = StorageType.File encoding: str | None = None base_dir: str | None = None azure_connection_string: None = None diff --git a/packages/graphrag/graphrag/config/init_content.py b/packages/graphrag/graphrag/config/init_content.py index 831c0fbba2..01554d9ce5 100644 --- a/packages/graphrag/graphrag/config/init_content.py +++ b/packages/graphrag/graphrag/config/init_content.py @@ -50,7 +50,7 @@ input: storage: - type: {graphrag_config_defaults.input.storage.type} # or AzureBlob, AzureCosmos + type: {graphrag_config_defaults.input.storage.type} # or blob, cosmosdb base_dir: "{graphrag_config_defaults.input.storage.base_dir}" file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json] @@ -63,7 +63,7 @@ ## connection_string and container_name must be provided output: - type: {graphrag_config_defaults.output.type} # or AzureBlob, AzureCosmos + type: {graphrag_config_defaults.output.type} # or blob, cosmosdb base_dir: "{graphrag_config_defaults.output.base_dir}" cache: diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index d19f8f733d..101a37b379 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -116,7 +116,7 @@ def _validate_input_pattern(self) -> None: def _validate_input_base_dir(self) -> None: """Validate the input base directory.""" - if self.input.storage.type == StorageType.FILE: + if self.input.storage.type == StorageType.File: if not self.input.storage.base_dir: msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration." 
raise ValueError(msg) @@ -140,7 +140,7 @@ def _validate_input_base_dir(self) -> None: def _validate_output_base_dir(self) -> None: """Validate the output base directory.""" - if self.output.type == StorageType.FILE: + if self.output.type == StorageType.File: if not self.output.base_dir: msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration." raise ValueError(msg) @@ -156,7 +156,7 @@ def _validate_output_base_dir(self) -> None: def _validate_update_index_output_base_dir(self) -> None: """Validate the update index output base directory.""" - if self.update_index_output.type == StorageType.FILE: + if self.update_index_output.type == StorageType.File: if not self.update_index_output.base_dir: msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration." raise ValueError(msg) diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 32d8f9b26d..8e370d50f6 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -29,7 +29,7 @@ @pytest.mark.skip(reason="Blob storage emulator is not available in this environment") def test_create_blob_storage(): config = StorageConfig( - type=StorageType.AZURE_BLOB, + type=StorageType.AzureBlob, azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, base_dir="testbasedir", azure_container_name="testcontainer", @@ -44,7 +44,7 @@ def test_create_blob_storage(): ) def test_create_cosmosdb_storage(): config = StorageConfig( - type=StorageType.AZURE_COSMOS, + type=StorageType.AzureCosmos, azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, azure_cosmosdb_database_name="testdatabase", azure_container_name="testcontainer", @@ -55,7 +55,7 @@ def test_create_cosmosdb_storage(): def test_create_file(): config = StorageConfig( - type=StorageType.FILE, + type=StorageType.File, 
base_dir="/tmp/teststorage", ) storage = create_storage(config) @@ -65,7 +65,7 @@ def test_create_file(): def test_create_memory_storage(): config = StorageConfig( base_dir="", - type=StorageType.MEMORY, + type=StorageType.Memory, ) storage = create_storage(config) assert isinstance(storage, MemoryStorage) From 3dd1c2848170eb4b08ea0f507bef40e1d69b4927 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Wed, 3 Dec 2025 10:49:43 -0800 Subject: [PATCH 14/17] update readme --- packages/graphrag-storage/README.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/packages/graphrag-storage/README.md b/packages/graphrag-storage/README.md index cb04667b64..6eaf729242 100644 --- a/packages/graphrag-storage/README.md +++ b/packages/graphrag-storage/README.md @@ -57,25 +57,29 @@ if __name__ == "__main__": asyncio.run(run()) ``` -### Information +### Details -By default, the `create_storage` comes with the following storage providers registered that correspond to the entries in the `StorageType` enum. +By default, the `create_storage` comes with the following storage providers registered that correspond to the entries in the `StorageType` enum. - `FileStorage` - `AzureBlobStorage` - `AzureCosmosStorage` - `MemoryStorage` -You can directly import `storage_factory` if you want a clean factory with no preregistered storage providers. +The preregistration happens dynamically, e.g., `FileStorage` is only imported and registered if you request a `FileStorage` with `create_storage(StorageType.File, ...)`. There is no need to manually import and register builtin storage providers when using `create_storage`. + +If you want a clean factory with no preregistered storage providers then directly import `storage_factory` and bypass using `create_storage`. The downside is that `storage_factory.create` uses a dict for init args instead of the strongly typed `StorageConfig` used with `create_storage`. 
```python from graphrag_storage.storage_factory import storage_factory from graphrag_storage.file_storage import FileStorage -# Or register a custom implementation, see above for example. +# storage_factory has no preregistered providers so you must register any +# providers you plan on using. +# May also register a custom implementation, see above for example. storage_factory.register("my_storage_key", FileStorage) -storage = storage_factory.create(strategy="my_storage_key", init_args={"base_dir": "...", "other_settings": "...",}) +storage = storage_factory.create(strategy="my_storage_key", init_args={"base_dir": "...", "other_settings": "..."}) ... From ef97d7bcba38015d4d0bef6d0d9546050297373c Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Fri, 12 Dec 2025 07:30:04 -0800 Subject: [PATCH 15/17] cleanup --- packages/graphrag/graphrag/storage/__init__.py | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 packages/graphrag/graphrag/storage/__init__.py diff --git a/packages/graphrag/graphrag/storage/__init__.py b/packages/graphrag/graphrag/storage/__init__.py deleted file mode 100644 index 94146bcd02..0000000000 --- a/packages/graphrag/graphrag/storage/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""The storage package root.""" - -from graphrag_storage import create_storage, register_storage - -__all__ = ["create_storage", "register_storage"] From 4b7d2d7f8ad2565fe1f3f66fdd77d6dbe0622193 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Mon, 15 Dec 2025 06:10:57 -0800 Subject: [PATCH 16/17] cleanup config --- .../graphrag_storage/azure_blob_storage.py | 36 +++++++++---------- .../graphrag_storage/azure_cosmos_storage.py | 32 ++++++++--------- .../graphrag_storage/storage_config.py | 14 ++++---- tests/integration/cache/test_factory.py | 10 +++--- .../integration/storage/test_blob_storage.py | 12 +++---- .../storage/test_cosmosdb_storage.py | 24 ++++++------- tests/integration/storage/test_factory.py | 10 +++--- tests/smoke/test_fixtures.py | 4 +-- tests/unit/config/utils.py | 8 ++--- 9 files changed, 72 insertions(+), 78 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py index 35001e48e5..bec1bdb465 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py @@ -33,33 +33,33 @@ class AzureBlobStorage(Storage): def __init__( self, - azure_container_name: str, - azure_account_url: str | None = None, - azure_connection_string: str | None = None, + container_name: str, + account_url: str | None = None, + connection_string: str | None = None, base_dir: str | None = None, encoding: str = "utf-8", **kwargs: Any, ) -> None: """Create a new BlobStorage instance.""" - if azure_connection_string is not None and azure_account_url is not None: + if connection_string is not None and account_url is not None: msg = "AzureBlobStorage requires only one of connection_string or storage_account_blob_url to be specified, not both." 
logger.error(msg) raise ValueError(msg) - _validate_blob_container_name(azure_container_name) + _validate_blob_container_name(container_name) logger.info( "Creating blob storage at [%s] and base_dir [%s]", - azure_container_name, + container_name, base_dir, ) - if azure_connection_string: + if connection_string: self._blob_service_client = BlobServiceClient.from_connection_string( - azure_connection_string + connection_string ) - elif azure_account_url: + elif account_url: self._blob_service_client = BlobServiceClient( - account_url=azure_account_url, + account_url=account_url, credential=DefaultAzureCredential(), ) else: @@ -68,14 +68,12 @@ def __init__( raise ValueError(msg) self._encoding = encoding - self._container_name = azure_container_name - self._connection_string = azure_connection_string + self._container_name = container_name + self._connection_string = connection_string self._base_dir = base_dir - self._storage_account_blob_url = azure_account_url + self._storage_account_blob_url = account_url self._storage_account_name = ( - azure_account_url.split("//")[1].split(".")[0] - if azure_account_url - else None + account_url.split("//")[1].split(".")[0] if account_url else None ) self._create_container() @@ -223,11 +221,11 @@ def child(self, name: str | None) -> "Storage": return self path = str(Path(self._base_dir) / name) if self._base_dir else name return AzureBlobStorage( - azure_connection_string=self._connection_string, - azure_container_name=self._container_name, + connection_string=self._connection_string, + container_name=self._container_name, encoding=self._encoding, base_dir=path, - azure_account_url=self._storage_account_blob_url, + account_url=self._storage_account_blob_url, ) def keys(self) -> list[str]: diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index 78a63eb2cb..ff3ec6decb 100644 --- 
a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -41,33 +41,31 @@ class AzureCosmosStorage(Storage): def __init__( self, - azure_cosmosdb_database_name: str, - azure_container_name: str, - azure_connection_string: str | None = None, - azure_account_url: str | None = None, + database_name: str, + container_name: str, + connection_string: str | None = None, + account_url: str | None = None, encoding: str = "utf-8", **kwargs: Any, ) -> None: """Create a CosmosDB storage instance.""" logger.info("Creating cosmosdb storage") - database_name = azure_cosmosdb_database_name + database_name = database_name if database_name is None: msg = "CosmosDB Storage requires a base_dir to be specified. This is used as the database name." logger.error(msg) raise ValueError(msg) - if azure_connection_string is not None and azure_account_url is not None: + if connection_string is not None and account_url is not None: msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified, not both." 
logger.error(msg) raise ValueError(msg) - if azure_connection_string: - self._cosmos_client = CosmosClient.from_connection_string( - azure_connection_string - ) - elif azure_account_url: + if connection_string: + self._cosmos_client = CosmosClient.from_connection_string(connection_string) + elif account_url: self._cosmos_client = CosmosClient( - url=azure_account_url, + url=account_url, credential=DefaultAzureCredential(), ) else: @@ -77,13 +75,11 @@ def __init__( self._encoding = encoding self._database_name = database_name - self._connection_string = azure_connection_string - self._cosmosdb_account_url = azure_account_url - self._container_name = azure_container_name + self._connection_string = connection_string + self._cosmosdb_account_url = account_url + self._container_name = container_name self._cosmosdb_account_name = ( - azure_account_url.split("//")[1].split(".")[0] - if azure_account_url - else None + account_url.split("//")[1].split(".")[0] if account_url else None ) self._no_id_prefixes = [] logger.debug( diff --git a/packages/graphrag-storage/graphrag_storage/storage_config.py b/packages/graphrag-storage/graphrag_storage/storage_config.py index abcd420010..45cd63a734 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_config.py +++ b/packages/graphrag-storage/graphrag_storage/storage_config.py @@ -29,20 +29,20 @@ class StorageConfig(BaseModel): default=None, ) - azure_connection_string: str | None = Field( - description="The connection string for Azure Blob Storage or Azure CosmosDB.", + connection_string: str | None = Field( + description="The connection string for remote services.", default=None, ) - azure_container_name: str | None = Field( + container_name: str | None = Field( description="The Azure Blob Storage container name or CosmosDB container name to use.", default=None, ) - azure_account_url: str | None = Field( - description="The account url for Azure Blob Storage or Azure CosmosDB.", + account_url: str | None = Field( + 
description="The account url for Azure services.", default=None, ) - azure_cosmosdb_database_name: str | None = Field( - description="The Azure CosmosDB database name to use.", + database_name: str | None = Field( + description="The database name to use.", default=None, ) diff --git a/tests/integration/cache/test_factory.py b/tests/integration/cache/test_factory.py index 8313e36bdc..361d46a4ea 100644 --- a/tests/integration/cache/test_factory.py +++ b/tests/integration/cache/test_factory.py @@ -41,8 +41,8 @@ def test_create_file_cache(): def test_create_blob_cache(): init_args = { - "azure_connection_string": WELL_KNOWN_BLOB_STORAGE_KEY, - "azure_container_name": "testcontainer", + "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY, + "container_name": "testcontainer", "base_dir": "testcache", } cache = CacheFactory().create(strategy=CacheType.blob.value, init_args=init_args) @@ -55,9 +55,9 @@ def test_create_blob_cache(): ) def test_create_cosmosdb_cache(): init_args = { - "azure_connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING, - "azure_cosmosdb_database_name": "testdatabase", - "azure_container_name": "testcontainer", + "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING, + "database_name": "testdatabase", + "container_name": "testcontainer", } cache = CacheFactory().create( strategy=CacheType.cosmosdb.value, init_args=init_args diff --git a/tests/integration/storage/test_blob_storage.py b/tests/integration/storage/test_blob_storage.py index 9b1654aa70..ec996e91a2 100644 --- a/tests/integration/storage/test_blob_storage.py +++ b/tests/integration/storage/test_blob_storage.py @@ -13,8 +13,8 @@ async def test_find(): storage = AzureBlobStorage( - azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, - azure_container_name="testfind", + connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + container_name="testfind", ) try: try: @@ -43,8 +43,8 @@ async def test_find(): async def test_get_creation_date(): storage = AzureBlobStorage( - 
azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, - azure_container_name="testfind", + connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + container_name="testfind", ) try: await storage.set("input/christmas.txt", "Merry Christmas!", encoding="utf-8") @@ -60,8 +60,8 @@ async def test_get_creation_date(): async def test_child(): parent = AzureBlobStorage( - azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, - azure_container_name="testchild", + connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + container_name="testchild", ) try: try: diff --git a/tests/integration/storage/test_cosmosdb_storage.py b/tests/integration/storage/test_cosmosdb_storage.py index 5db0d0898f..a044cc754f 100644 --- a/tests/integration/storage/test_cosmosdb_storage.py +++ b/tests/integration/storage/test_cosmosdb_storage.py @@ -22,9 +22,9 @@ async def test_find(): storage = AzureCosmosStorage( - azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - azure_cosmosdb_database_name="testfind", - azure_container_name="testfindcontainer", + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testfind", + container_name="testfindcontainer", ) try: try: @@ -65,9 +65,9 @@ async def test_find(): async def test_child(): storage = AzureCosmosStorage( - azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - azure_cosmosdb_database_name="testchild", - azure_container_name="testchildcontainer", + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testchild", + container_name="testchildcontainer", ) try: child_storage = storage.child("child") @@ -78,9 +78,9 @@ async def test_child(): async def test_clear(): storage = AzureCosmosStorage( - azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - azure_cosmosdb_database_name="testclear", - azure_container_name="testclearcontainer", + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testclear", + container_name="testclearcontainer", ) try: json_exists = { @@ -108,9 +108,9 @@ 
async def test_clear(): async def test_get_creation_date(): storage = AzureCosmosStorage( - azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - azure_cosmosdb_database_name="testclear", - azure_container_name="testclearcontainer", + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testclear", + container_name="testclearcontainer", ) try: json_content = { diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 8e370d50f6..f4d3092e78 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -30,9 +30,9 @@ def test_create_blob_storage(): config = StorageConfig( type=StorageType.AzureBlob, - azure_connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, + connection_string=WELL_KNOWN_BLOB_STORAGE_KEY, base_dir="testbasedir", - azure_container_name="testcontainer", + container_name="testcontainer", ) storage = create_storage(config) assert isinstance(storage, AzureBlobStorage) @@ -45,9 +45,9 @@ def test_create_blob_storage(): def test_create_cosmosdb_storage(): config = StorageConfig( type=StorageType.AzureCosmos, - azure_connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - azure_cosmosdb_database_name="testdatabase", - azure_container_name="testcontainer", + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testdatabase", + container_name="testcontainer", ) storage = create_storage(config) assert isinstance(storage, AzureCosmosStorage) diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index 2e3b2b09e3..53205c7c09 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -95,8 +95,8 @@ async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], Non root = Path(input_path) input_storage = AzureBlobStorage( - azure_connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING, - azure_container_name=input_container, + connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING, 
+ container_name=input_container, ) # Bounce the container if it exists to clear out old run data input_storage._delete_container() # noqa: SLF001 diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 3c2e5320e2..83e18ad546 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -129,11 +129,11 @@ def assert_reporting_configs( def assert_storage_config(actual: StorageConfig, expected: StorageConfig) -> None: assert expected.type == actual.type assert expected.base_dir == actual.base_dir - assert expected.azure_connection_string == actual.azure_connection_string - assert expected.azure_container_name == actual.azure_container_name - assert expected.azure_account_url == actual.azure_account_url + assert expected.connection_string == actual.connection_string + assert expected.container_name == actual.container_name + assert expected.account_url == actual.account_url assert expected.encoding == actual.encoding - assert expected.azure_cosmosdb_database_name == actual.azure_cosmosdb_database_name + assert expected.database_name == actual.database_name def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: From 1460c8165ce1132c3400a30c96d7dffbf850d18a Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Mon, 15 Dec 2025 06:38:34 -0800 Subject: [PATCH 17/17] cleanup --- .../graphrag-common/graphrag_common/factory/__init__.py | 4 ++-- .../graphrag-storage/graphrag_storage/storage_factory.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/packages/graphrag-common/graphrag_common/factory/__init__.py b/packages/graphrag-common/graphrag_common/factory/__init__.py index ff9d0dce40..86b102f265 100644 --- a/packages/graphrag-common/graphrag_common/factory/__init__.py +++ b/packages/graphrag-common/graphrag_common/factory/__init__.py @@ -3,6 +3,6 @@ """The GraphRAG factory module.""" -from graphrag_common.factory.factory import Factory +from graphrag_common.factory.factory import Factory, 
ServiceScope -__all__ = ["Factory"] +__all__ = ["Factory", "ServiceScope"] diff --git a/packages/graphrag-storage/graphrag_storage/storage_factory.py b/packages/graphrag-storage/graphrag_storage/storage_factory.py index 1341d7ecf4..0b525fec5f 100644 --- a/packages/graphrag-storage/graphrag_storage/storage_factory.py +++ b/packages/graphrag-storage/graphrag_storage/storage_factory.py @@ -6,7 +6,7 @@ from collections.abc import Callable -from graphrag_common.factory import Factory +from graphrag_common.factory import Factory, ServiceScope from graphrag_storage.storage import Storage from graphrag_storage.storage_config import StorageConfig @@ -21,7 +21,9 @@ class StorageFactory(Factory[Storage]): def register_storage( - storage_type: str, storage_initializer: Callable[..., Storage] + storage_type: str, + storage_initializer: Callable[..., Storage], + scope: ServiceScope = "transient", ) -> None: """Register a custom storage implementation. @@ -32,7 +34,9 @@ def register_storage( - storage_initializer: Callable[..., Storage] The storage initializer to register. + - scope: ServiceScope (default="transient") + The scope under which to register the storage initializer. """ - storage_factory.register(storage_type, storage_initializer) + storage_factory.register(storage_type, storage_initializer, scope) def create_storage(config: StorageConfig) -> Storage: