Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250331184323312702.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "add vector store integration tests"
}
87 changes: 71 additions & 16 deletions graphrag/vector_stores/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any

from azure.cosmos import ContainerProxy, CosmosClient, DatabaseProxy
from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.cosmos.partition_key import PartitionKey
from azure.identity import DefaultAzureCredential

Expand All @@ -19,7 +20,7 @@
)


class CosmosDBVectoreStore(BaseVectorStore):
class CosmosDBVectorStore(BaseVectorStore):
"""Azure CosmosDB vector storage implementation."""

_cosmos_client: CosmosClient
Expand Down Expand Up @@ -99,16 +100,32 @@ def _create_container(self) -> None:
"automatic": True,
"includedPaths": [{"path": "/*"}],
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
"vectorIndexes": [{"path": "/vector", "type": "diskANN"}],
}

# Create the container and container client
self._database_client.create_container_if_not_exists(
id=self._container_name,
partition_key=partition_key,
indexing_policy=indexing_policy,
vector_embedding_policy=vector_embedding_policy,
)
# Currently, the CosmosDB emulator does not support the diskANN policy.
try:
# First try with the standard diskANN policy
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]

# Create the container and container client
self._database_client.create_container_if_not_exists(
id=self._container_name,
partition_key=partition_key,
indexing_policy=indexing_policy,
vector_embedding_policy=vector_embedding_policy,
)
except CosmosHttpResponseError:
# If diskANN fails (likely in emulator), retry without vector indexes
indexing_policy.pop("vectorIndexes", None)

# Create the container with compatible indexing policy
self._database_client.create_container_if_not_exists(
id=self._container_name,
partition_key=partition_key,
indexing_policy=indexing_policy,
vector_embedding_policy=vector_embedding_policy,
)

self._container_client = self._database_client.get_container_client(
self._container_name
)
Expand Down Expand Up @@ -157,13 +174,46 @@ def similarity_search_by_vector(
msg = "Container client is not initialized."
raise ValueError(msg)

query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
query_params = [{"name": "@embedding", "value": query_embedding}]
items = self._container_client.query_items(
query=query,
parameters=query_params,
enable_cross_partition_query=True,
)
try:
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
query_params = [{"name": "@embedding", "value": query_embedding}]
items = list(
self._container_client.query_items(
query=query,
parameters=query_params,
enable_cross_partition_query=True,
)
)
except (CosmosHttpResponseError, ValueError):
# Currently, the CosmosDB emulator does not support the VectorDistance function.
# For emulator or test environments - fetch all items and calculate distance locally
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
items = list(
self._container_client.query_items(
query=query,
enable_cross_partition_query=True,
)
)

# Calculate cosine similarity locally (1 - cosine distance)
from numpy import dot
from numpy.linalg import norm

def cosine_similarity(a, b):
if norm(a) * norm(b) == 0:
return 0.0
return dot(a, b) / (norm(a) * norm(b))

# Calculate scores for all items
for item in items:
item_vector = item.get("vector", [])
similarity = cosine_similarity(query_embedding, item_vector)
item["SimilarityScore"] = similarity

# Sort by similarity score (higher is better) and take top k
items = sorted(
items, key=lambda x: x.get("SimilarityScore", 0.0), reverse=True
)[:k]

return [
VectorStoreSearchResult(
Expand Down Expand Up @@ -214,3 +264,8 @@ def search_by_id(self, id: str) -> VectorStoreDocument:
text=item.get("text", ""),
attributes=(json.loads(item.get("attributes", "{}"))),
)

def clear(self) -> None:
    """Clear the vector store.

    Tears down the CosmosDB resources backing this store by deleting the
    container first, then the database that held it.
    """
    # Order matters here: drop the container before its parent database.
    # NOTE(review): assumes both delete helpers tolerate already-missing
    # resources — confirm against their implementations.
    self._delete_container()
    self._delete_database()
4 changes: 2 additions & 2 deletions graphrag/vector_stores/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
from graphrag.vector_stores.lancedb import LanceDBVectorStore


Expand Down Expand Up @@ -44,7 +44,7 @@ def create_vector_store(
case VectorStoreType.AzureAISearch:
return AzureAISearchVectorStore(**kwargs)
case VectorStoreType.CosmosDB:
return CosmosDBVectoreStore(**kwargs)
return CosmosDBVectorStore(**kwargs)
case _:
if vector_store_type in cls.vector_store_types:
return cls.vector_store_types[vector_store_type](**kwargs)
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/vector_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Integration tests for vector store implementations."""
146 changes: 146 additions & 0 deletions tests/integration/vector_stores/test_azure_ai_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Integration tests for Azure AI Search vector store implementation."""

import os
from unittest.mock import MagicMock, patch

import pytest

from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import VectorStoreDocument

# Connection settings for the Azure AI Search tests. Real values can be
# injected via environment variables; the fallbacks are placeholders that
# only work because the service clients are mocked in the fixtures below.
TEST_AZURE_AI_SEARCH_URL = os.environ.get(
    "TEST_AZURE_AI_SEARCH_URL", "https://test-url.search.windows.net"
)
TEST_AZURE_AI_SEARCH_KEY = os.environ.get("TEST_AZURE_AI_SEARCH_KEY", "test_api_key")


class TestAzureAISearchVectorStore:
    """Test class for AzureAISearchVectorStore.

    All Azure SDK clients are replaced with mocks, so these tests exercise
    only the store's own logic (document mapping, filter construction,
    result parsing) without any network access.
    """

    @pytest.fixture
    def mock_search_client(self):
        """Create a mock Azure AI Search client."""
        with patch(
            "graphrag.vector_stores.azure_ai_search.SearchClient"
        ) as mock_client:
            yield mock_client.return_value

    @pytest.fixture
    def mock_index_client(self):
        """Create a mock Azure AI Search index client."""
        with patch(
            "graphrag.vector_stores.azure_ai_search.SearchIndexClient"
        ) as mock_client:
            yield mock_client.return_value

    @pytest.fixture
    def vector_store(self, mock_search_client, mock_index_client):
        """Create an Azure AI Search vector store instance wired to the mocks."""
        vector_store = AzureAISearchVectorStore(collection_name="test_vectors")

        # Create the necessary mocks first
        # NOTE(review): connect() may rebuild these clients internally —
        # confirm the mocks survive the connect call.
        vector_store.db_connection = mock_search_client
        vector_store.index_client = mock_index_client

        vector_store.connect(
            url=TEST_AZURE_AI_SEARCH_URL,
            api_key=TEST_AZURE_AI_SEARCH_KEY,
            vector_size=5,
        )
        return vector_store

    @pytest.fixture
    def sample_documents(self):
        """Create sample documents for testing."""
        return [
            VectorStoreDocument(
                id="doc1",
                text="This is document 1",
                vector=[0.1, 0.2, 0.3, 0.4, 0.5],
                attributes={"title": "Doc 1", "category": "test"},
            ),
            VectorStoreDocument(
                id="doc2",
                text="This is document 2",
                vector=[0.2, 0.3, 0.4, 0.5, 0.6],
                attributes={"title": "Doc 2", "category": "test"},
            ),
        ]

    # Fix: these tests were declared `async` but contain no `await`. Without
    # an asyncio pytest plugin, async test coroutines are collected but never
    # executed, so every assertion below was silently skipped. The store API
    # under test is synchronous, so plain sync tests are correct.
    def test_vector_store_operations(
        self, vector_store, sample_documents, mock_search_client, mock_index_client
    ):
        """Test basic vector store operations with Azure AI Search."""
        # Setup mock responses
        mock_index_client.list_index_names.return_value = []
        mock_index_client.create_or_update_index = MagicMock()
        mock_search_client.upload_documents = MagicMock()

        search_results = [
            {
                "id": "doc1",
                "text": "This is document 1",
                "vector": [0.1, 0.2, 0.3, 0.4, 0.5],
                "attributes": '{"title": "Doc 1", "category": "test"}',
                "@search.score": 0.9,
            },
            {
                "id": "doc2",
                "text": "This is document 2",
                "vector": [0.2, 0.3, 0.4, 0.5, 0.6],
                "attributes": '{"title": "Doc 2", "category": "test"}',
                "@search.score": 0.8,
            },
        ]
        mock_search_client.search.return_value = search_results

        mock_search_client.get_document.return_value = {
            "id": "doc1",
            "text": "This is document 1",
            "vector": [0.1, 0.2, 0.3, 0.4, 0.5],
            "attributes": '{"title": "Doc 1", "category": "test"}',
        }

        # Loading documents should (re)create the index and upload the docs.
        vector_store.load_documents(sample_documents)
        assert mock_index_client.create_or_update_index.called
        assert mock_search_client.upload_documents.called

        # filter_by_id should produce an OData search.in expression.
        filter_query = vector_store.filter_by_id(["doc1", "doc2"])
        assert filter_query == "search.in(id, 'doc1,doc2', ',')"

        # Vector search results should be mapped back to documents with scores.
        vector_results = vector_store.similarity_search_by_vector(
            [0.1, 0.2, 0.3, 0.4, 0.5], k=2
        )
        assert len(vector_results) == 2
        assert vector_results[0].document.id == "doc1"
        assert vector_results[0].score == 0.9

        # Define a simple text embedder function for testing
        def mock_embedder(text: str) -> list[float]:
            return [0.1, 0.2, 0.3, 0.4, 0.5]

        text_results = vector_store.similarity_search_by_text(
            "test query", mock_embedder, k=2
        )
        assert len(text_results) == 2

        # Lookup by id should round-trip text and JSON-encoded attributes.
        doc = vector_store.search_by_id("doc1")
        assert doc.id == "doc1"
        assert doc.text == "This is document 1"
        assert doc.attributes["title"] == "Doc 1"

    def test_empty_embedding(self, vector_store, mock_search_client):
        """Test similarity search by text with empty embedding."""

        # Create a mock embedder that returns None and verify that no results are produced
        def none_embedder(text: str) -> None:
            return None

        results = vector_store.similarity_search_by_text(
            "test query", none_embedder, k=1
        )
        # With no embedding, the store must not hit the search service at all.
        assert not mock_search_client.search.called
        assert len(results) == 0
Loading
Loading