Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260227202720480258.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Vector store load_documents now uploads documents in batches"
}
1 change: 1 addition & 0 deletions dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ retryer
agenerate
dropna
notna
upserted

# LLM Terms
AOAI
Expand Down
19 changes: 12 additions & 7 deletions packages/graphrag-vectors/graphrag_vectors/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,22 +165,27 @@ def create_index(self) -> None:
index,
)

def insert(self, document: VectorStoreDocument) -> None:
"""Insert a single document into Azure AI Search."""
self._prepare_document(document)
if document.vector is not None:
doc_dict = {
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
"""Load documents into Azure AI Search as a single batch upload."""
batch: list[dict[str, Any]] = []
for document in documents:
self._prepare_document(document)
if document.vector is None:
continue
doc_dict: dict[str, Any] = {
self.id_field: document.id,
self.vector_field: document.vector,
self.create_date_field: document.create_date,
self.update_date_field: document.update_date,
}
# Add additional fields if they exist in the document data
if document.data:
for field_name in self.fields:
if field_name in document.data:
doc_dict[field_name] = document.data[field_name]
self.db_connection.upload_documents([doc_dict])
batch.append(doc_dict)

if batch:
self.db_connection.upload_documents(batch)

def _compile_filter(self, expr: FilterExpr) -> str:
"""Compile a FilterExpr into an Azure AI Search OData filter string."""
Expand Down
15 changes: 10 additions & 5 deletions packages/graphrag-vectors/graphrag_vectors/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,22 @@ def create_index(self) -> None:
msg = "Container client is not initialized."
raise ValueError(msg)

def insert(self, document: VectorStoreDocument) -> None:
"""Insert a single document into CosmosDB."""
self._prepare_document(document)
if document.vector is not None:
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
"""Load documents into CosmosDB.

CosmosDB does not support native batch upsert, so each
document is upserted individually after preparation.
"""
for document in documents:
self._prepare_document(document)
if document.vector is None:
continue
doc_json: dict[str, Any] = {
self.id_field: document.id,
self.vector_field: document.vector,
self.create_date_field: document.create_date,
self.update_date_field: document.update_date,
}
# Add additional fields if they exist in the document data
if document.data:
for field_name in self.fields:
if field_name in document.data:
Expand Down
65 changes: 35 additions & 30 deletions packages/graphrag-vectors/graphrag_vectors/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,38 +78,43 @@ def create_index(self) -> None:
# Remove the dummy document used to set up the schema
self.document_collection.delete(f"{self.id_field} = '__DUMMY__'")

def insert(self, document: VectorStoreDocument) -> None:
"""Insert a single document into LanceDB."""
self._prepare_document(document)
if document.vector is not None:
vector = np.array(document.vector, dtype=np.float32)
flat_array = pa.array(vector, type=pa.float32())
vector_column = pa.FixedSizeListArray.from_arrays(
flat_array, self.vector_size
)

others = {}
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
"""Load documents into LanceDB as a single batch write."""
ids: list[str] = []
vectors: list[np.ndarray] = []
create_dates: list[str | None] = []
update_dates: list[str | None] = []
field_columns: dict[str, list[Any]] = {name: [] for name in self.fields}

for document in documents:
self._prepare_document(document)
if document.vector is None:
continue

ids.append(str(document.id))
vectors.append(np.array(document.vector, dtype=np.float32))
create_dates.append(document.create_date)
update_dates.append(document.update_date)
for field_name in self.fields:
others[field_name] = (
document.data.get(field_name) if document.data else None
)

data = pa.table({
self.id_field: pa.array([document.id], type=pa.string()),
self.vector_field: vector_column,
self.create_date_field: pa.array(
[document.create_date], type=pa.string()
),
self.update_date_field: pa.array(
[document.update_date], type=pa.string()
),
**{
field_name: pa.array([value])
for field_name, value in others.items()
},
})
value = document.data.get(field_name) if document.data else None
field_columns[field_name].append(value)

if not ids:
return

flat_vector = np.concatenate(vectors).astype(np.float32)
flat_array = pa.array(flat_vector, type=pa.float32())
vector_column = pa.FixedSizeListArray.from_arrays(flat_array, self.vector_size)

data = pa.table({
self.id_field: pa.array(ids, type=pa.string()),
self.vector_field: vector_column,
self.create_date_field: pa.array(create_dates, type=pa.string()),
self.update_date_field: pa.array(update_dates, type=pa.string()),
**{name: pa.array(values) for name, values in field_columns.items()},
})

self.document_collection.add(data)
self.document_collection.add(data)

def _extract_data(
self, doc: dict[str, Any], select: list[str] | None = None
Expand Down
7 changes: 3 additions & 4 deletions packages/graphrag-vectors/graphrag_vectors/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,13 @@ def connect(self) -> None:
def create_index(self) -> None:
"""Create index."""

@abstractmethod
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
"""Load documents into the vector-store."""
for doc in documents:
self.insert(doc)

@abstractmethod
def insert(self, document: VectorStoreDocument) -> None:
"""Insert a single document into the vector-store."""
"""Insert a single document by delegating to load_documents."""
self.load_documents([document])

@abstractmethod
def similarity_search_by_vector(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def embed_text(

buffer: list[dict[str, Any]] = []
total_rows = 0
flush_size = batch_size * 4

async for row in input_table:
text = row.get(embed_column)
Expand All @@ -49,7 +50,7 @@ async def embed_text(
embed_column: text,
})

if len(buffer) >= batch_size:
if len(buffer) >= flush_size:
total_rows += await _flush_embedding_buffer(
buffer,
embed_column,
Expand Down
13 changes: 6 additions & 7 deletions tests/unit/indexing/operations/embed_text/test_embed_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ async def test_embed_text_basic():

@pytest.mark.asyncio
async def test_embed_text_batching():
"""Verify rows are flushed in batches when batch_size < total rows."""
rows = [{"id": str(i), "text": f"text {i}"} for i in range(5)]
"""Verify rows are flushed in batches when buffer exceeds batch_size * 4."""
rows = [{"id": str(i), "text": f"text {i}"} for i in range(10)]
input_table = FakeInputTable(rows)
vector_store = _make_mock_vector_store()

Expand All @@ -160,9 +160,8 @@ async def test_embed_text_batching():
new_callable=AsyncMock,
) as mock_run:
mock_run.side_effect = [
_make_embedding_result(2, [1.0]),
_make_embedding_result(8, [1.0]),
_make_embedding_result(2, [2.0]),
_make_embedding_result(1, [3.0]),
]

count = await embed_text(
Expand All @@ -177,9 +176,9 @@ async def test_embed_text_batching():
vector_store=vector_store,
)

assert count == 5
assert mock_run.call_count == 3
assert vector_store.load_documents.call_count == 3
assert count == 10
assert mock_run.call_count == 2
assert vector_store.load_documents.call_count == 2


@pytest.mark.asyncio
Expand Down