Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/config/yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th
#### Fields

- `storage` **StorageConfig**
- `type` **file|blob|cosmosdb** - The storage type to use. Default=`file`
- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file`
- `base_dir` **str** - The base directory to write output artifacts to, relative to the root.
- `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string.
- `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name.
Expand Down
4 changes: 2 additions & 2 deletions packages/graphrag-common/graphrag_common/factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

"""The GraphRAG factory module."""

from graphrag_common.factory.factory import Factory
from graphrag_common.factory.factory import Factory, ServiceScope

__all__ = ["Factory"]
__all__ = ["Factory", "ServiceScope"]
5 changes: 4 additions & 1 deletion packages/graphrag-common/graphrag_common/factory/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,14 @@ def create(self, strategy: str, init_args: dict[str, Any] | None = None) -> T:
msg = f"Strategy '{strategy}' is not registered. Registered strategies are: {', '.join(list(self._service_initializers.keys()))}"
raise ValueError(msg)

# Delete entries with value None
init_args = {k: v for k, v in (init_args or {}).items() if v is not None}

service_descriptor = self._service_initializers[strategy]
if service_descriptor.scope == "singleton":
if strategy not in self._initialized_services:
self._initialized_services[strategy] = service_descriptor.initializer(
**(init_args or {})
**init_args
)
return self._initialized_services[strategy]

Expand Down
86 changes: 86 additions & 0 deletions packages/graphrag-storage/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# GraphRAG Storage

## Basic

```python
import asyncio
from graphrag_storage import StorageConfig, create_storage, StorageType

async def run():
storage = create_storage(
StorageConfig(
type=StorageType.File
base_dir="output"
)
)

await storage.set("my_key", "value")
print(await storage.get("my_key"))

if __name__ == "__main__":
asyncio.run(run())
```

## Custom Storage

```python
import asyncio
from typing import Any
from graphrag_storage import Storage, StorageConfig, create_storage, register_storage

class MyStorage(Storage):
def __init__(self, some_setting: str, optional_setting: str = "default setting", **kwargs: Any):
# Validate settings and initialize
...

#Implement rest of interface
...

register_storage("MyStorage", MyStorage)

async def run():
storage = create_storage(
StorageConfig(
type="MyStorage"
some_setting="My Setting"
)
)
# Or use the factory directly to instantiate with a dict instead of using
# StorageConfig + create_factory
# from graphrag_storage.storage_factory import storage_factory
# storage = storage_factory.create(strategy="MyStorage", init_args={"some_setting": "My Setting"})

await storage.set("my_key", "value")
print(await storage.get("my_key"))

if __name__ == "__main__":
asyncio.run(run())
```

### Details

By default, the `create_storage` comes with the following storage providers registered that correspond to the entries in the `StorageType` enum.

- `FileStorage`
- `AzureBlobStorage`
- `AzureCosmosStorage`
- `MemoryStorage`

The preregistration happens dynamically, e.g., `FileStorage` is only imported and registered if you request a `FileStorage` with `create_storage(StorageType.File, ...)`. There is no need to manually import and register builtin storage providers when using `create_storage`.

If you want a clean factory with no preregistered storage providers then directly import `storage_factory` and bypass using `create_storage`. The downside is that `storage_factory.create` uses a dict for init args instead of the strongly typed `StorageConfig` used with `create_storage`.

```python
from graphrag_storage.storage_factory import storage_factory
from graphrag_storage.file_storage import FileStorage

# storage_factory has no preregistered providers so you must register any
# providers you plan on using.
# May also register a custom implementation, see above for example.
storage_factory.register("my_storage_key", FileStorage)

storage = storage_factory.create(strategy="my_storage_key", init_args={"base_dir": "...", "other_settings": "..."})

...

```
20 changes: 20 additions & 0 deletions packages/graphrag-storage/graphrag_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The GraphRAG Storage package."""

from graphrag_storage.storage import Storage
from graphrag_storage.storage_config import StorageConfig
from graphrag_storage.storage_factory import (
create_storage,
register_storage,
)
from graphrag_storage.storage_type import StorageType

__all__ = [
"Storage",
"StorageConfig",
"StorageType",
"create_storage",
"register_storage",
]
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Azure Blob Storage implementation of PipelineStorage."""
"""Azure Blob Storage implementation of Storage."""

import logging
import re
Expand All @@ -12,61 +12,68 @@
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient

from graphrag.storage.pipeline_storage import (
PipelineStorage,
from graphrag_storage.storage import (
Storage,
get_timestamp_formatted_with_local_tz,
)

logger = logging.getLogger(__name__)


class BlobPipelineStorage(PipelineStorage):
class AzureBlobStorage(Storage):
"""The Blob-Storage implementation."""

_connection_string: str | None
_container_name: str
_base_dir: str | None
_encoding: str
_storage_account_blob_url: str | None
_blob_service_client: BlobServiceClient
_storage_account_name: str | None

def __init__(self, **kwargs: Any) -> None:
def __init__(
self,
container_name: str,
account_url: str | None = None,
connection_string: str | None = None,
base_dir: str | None = None,
encoding: str = "utf-8",
**kwargs: Any,
) -> None:
"""Create a new BlobStorage instance."""
connection_string = kwargs.get("connection_string")
storage_account_blob_url = kwargs.get("storage_account_blob_url")
base_dir = kwargs.get("base_dir")
container_name = kwargs["container_name"]
if container_name is None:
msg = "No container name provided for blob storage."
raise ValueError(msg)
if connection_string is None and storage_account_blob_url is None:
msg = "No storage account blob url provided for blob storage."
if connection_string is not None and account_url is not None:
msg = "AzureBlobStorage requires only one of connection_string or storage_account_blob_url to be specified, not both."
logger.error(msg)
raise ValueError(msg)

_validate_blob_container_name(container_name)

logger.info(
"Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir
"Creating blob storage at [%s] and base_dir [%s]",
container_name,
base_dir,
)
if connection_string:
self._blob_service_client = BlobServiceClient.from_connection_string(
connection_string
)
else:
if storage_account_blob_url is None:
msg = "Either connection_string or storage_account_blob_url must be provided."
raise ValueError(msg)

elif account_url:
self._blob_service_client = BlobServiceClient(
account_url=storage_account_blob_url,
account_url=account_url,
credential=DefaultAzureCredential(),
)
self._encoding = kwargs.get("encoding", "utf-8")
else:
msg = "AzureBlobStorage requires either a connection_string or storage_account_blob_url to be specified."
logger.error(msg)
raise ValueError(msg)

self._encoding = encoding
self._container_name = container_name
self._connection_string = connection_string
self._base_dir = base_dir
self._storage_account_blob_url = storage_account_blob_url
self._storage_account_blob_url = account_url
self._storage_account_name = (
storage_account_blob_url.split("//")[1].split(".")[0]
if storage_account_blob_url
else None
account_url.split("//")[1].split(".")[0] if account_url else None
)
self._create_container()

Expand Down Expand Up @@ -208,17 +215,17 @@ async def delete(self, key: str) -> None:
async def clear(self) -> None:
"""Clear the cache."""

def child(self, name: str | None) -> "PipelineStorage":
def child(self, name: str | None) -> "Storage":
"""Create a child storage instance."""
if name is None:
return self
path = str(Path(self._base_dir) / name) if self._base_dir else name
return BlobPipelineStorage(
return AzureBlobStorage(
connection_string=self._connection_string,
container_name=self._container_name,
encoding=self._encoding,
base_dir=path,
storage_account_blob_url=self._storage_account_blob_url,
account_url=self._storage_account_blob_url,
)

def keys(self) -> list[str]:
Expand All @@ -245,7 +252,7 @@ async def get_creation_date(self, key: str) -> str:
return ""


def validate_blob_container_name(container_name: str):
def _validate_blob_container_name(container_name: str) -> None:
Comment thread
dworthen marked this conversation as resolved.
"""
Check if the provided blob container name is valid based on Azure rules.

Expand All @@ -265,34 +272,7 @@ def validate_blob_container_name(container_name: str):
-------
bool: True if valid, False otherwise.
"""
# Check the length of the name
if len(container_name) < 3 or len(container_name) > 63:
return ValueError(
f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long."
)

# Check if the name starts with a letter or number
if not container_name[0].isalnum():
return ValueError(
f"Container name must start with a letter or number. Starting character was {container_name[0]}."
)

# Check for valid characters (letters, numbers, hyphen) and lowercase letters
if not re.match(r"^[a-z0-9-]+$", container_name):
return ValueError(
f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}."
)

# Check for consecutive hyphens
if "--" in container_name:
return ValueError(
f"Container name cannot contain consecutive hyphens. Name provided was {container_name}."
)

# Check for hyphens at the end of the name
if container_name[-1] == "-":
return ValueError(
f"Container name cannot end with a hyphen. Name provided was {container_name}."
)

return True
# Match alphanumeric or single hyphen not at the start or end, repeated 3-63 times.
if not re.match(r"^(?:[0-9a-z]|(?<!^)-(?!$)){3,63}$", container_name):
msg = f"Container name must be between 3 and 63 characters long and contain only lowercase letters, numbers, or hyphens. Name provided was {container_name}."
raise ValueError(msg)
Loading