From f23cb101bc1acea2eba606afd753cef3af7d7d38 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Fri, 20 Feb 2026 18:47:15 -0300 Subject: [PATCH 1/2] addd streaming --- .../patch-20260220214632816094.json | 4 + .../index_migration_to_v1.ipynb | 39 +- .../graphrag/data_model/row_transformers.py | 10 + .../index/operations/embed_text/embed_text.py | 147 +++++-- .../workflows/generate_text_embeddings.py | 234 +++++----- .../index/workflows/update_text_embeddings.py | 50 +-- tests/unit/indexing/operations/__init__.py | 2 + .../operations/embed_text/__init__.py | 2 + .../operations/embed_text/test_embed_text.py | 406 ++++++++++++++++++ tests/verbs/test_update_text_embeddings.py | 65 +++ 10 files changed, 716 insertions(+), 243 deletions(-) create mode 100644 .semversioner/next-release/patch-20260220214632816094.json create mode 100644 tests/unit/indexing/operations/__init__.py create mode 100644 tests/unit/indexing/operations/embed_text/__init__.py create mode 100644 tests/unit/indexing/operations/embed_text/test_embed_text.py create mode 100644 tests/verbs/test_update_text_embeddings.py diff --git a/.semversioner/next-release/patch-20260220214632816094.json b/.semversioner/next-release/patch-20260220214632816094.json new file mode 100644 index 0000000000..925bd32f8f --- /dev/null +++ b/.semversioner/next-release/patch-20260220214632816094.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "generate_text_embeddings streaming" +} diff --git a/docs/examples_notebooks/index_migration_to_v1.ipynb b/docs/examples_notebooks/index_migration_to_v1.ipynb index 3fd8f264bc..588c025d8b 100644 --- a/docs/examples_notebooks/index_migration_to_v1.ipynb +++ b/docs/examples_notebooks/index_migration_to_v1.ipynb @@ -205,44 +205,23 @@ "metadata": {}, "outputs": [], "source": [ - "from graphrag.cache.factory import CacheFactory\n", "from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n", "from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings\n", - "from graphrag.language_model.manager import ModelManager\n", - "from graphrag.tokenizer.get_tokenizer import get_tokenizer\n", + "from graphrag_cache import create_cache\n", "\n", - "# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n", - "# We'll construct the context and run this function flow directly to avoid everything else\n", + "# We only need to re-run the embeddings workflow, to ensure that embeddings\n", + "# for all required search fields are in place.\n", + "# We pass in the table_provider created earlier so that generate_text_embeddings\n", + "# reads the migrated tables we just wrote.\n", "\n", - "model_config = config.get_language_model_config(config.embed_text.model_id)\n", "callbacks = NoopWorkflowCallbacks()\n", - "cache_config = config.cache.model_dump() # type: ignore\n", - "cache = CacheFactory().create_cache(\n", - " cache_type=cache_config[\"type\"], # type: ignore\n", - " **cache_config,\n", - ")\n", - "model = ModelManager().get_or_create_embedding_model(\n", - " name=\"text_embedding\",\n", - " model_type=model_config.type,\n", - " config=model_config,\n", - " callbacks=callbacks,\n", - " cache=cache,\n", - ")\n", - "\n", - "tokenizer = get_tokenizer(model_config)\n", + "cache = create_cache(config.cache)\n", "\n", "await generate_text_embeddings(\n", - " text_units=final_text_units,\n", - " entities=final_entities,\n", - " community_reports=final_community_reports,\n", + " config=config,\n", + " table_provider=table_provider,\n", + " cache=cache,\n", " callbacks=callbacks,\n", - " model=model,\n", - " tokenizer=tokenizer,\n", - " batch_size=config.embed_text.batch_size,\n", - " batch_max_tokens=config.embed_text.batch_max_tokens,\n", - " num_threads=model_config.concurrent_requests,\n", - " vector_store_config=config.vector_store,\n", - " embedded_fields=config.embed_text.names,\n", ")" ] } diff --git a/packages/graphrag/graphrag/data_model/row_transformers.py b/packages/graphrag/graphrag/data_model/row_transformers.py index d767bd98b9..f26858407b 100644 --- a/packages/graphrag/graphrag/data_model/row_transformers.py +++ b/packages/graphrag/graphrag/data_model/row_transformers.py @@ -89,6 +89,16 @@ def transform_entity_row(row: dict[str, Any]) -> dict[str, Any]: return row +def transform_entity_row_for_embedding( + row: dict[str, Any], +) -> dict[str, Any]: + """Add a title_description column for embedding generation.""" + title = row.get("title") or "" + description = row.get("description") or "" + row["title_description"] = f"{title}:{description}" + return row + + # -- relationships (mirrors relationships_typed) -------------------------- diff --git a/packages/graphrag/graphrag/index/operations/embed_text/embed_text.py b/packages/graphrag/graphrag/index/operations/embed_text/embed_text.py index af0b79bbb5..72ba2a2ec1 100644 --- a/packages/graphrag/graphrag/index/operations/embed_text/embed_text.py +++ b/packages/graphrag/graphrag/index/operations/embed_text/embed_text.py @@ -1,14 +1,14 @@ -# Copyright (c) 2024 Microsoft Corporation. +# Copyright (C) 2026 Microsoft # Licensed under the MIT License -"""A module containing embed_text method definition.""" +"""Streaming text embedding operation.""" import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np -import pandas as pd from graphrag_llm.tokenizer import Tokenizer +from graphrag_storage.tables.table import Table from graphrag_vectors import VectorStore, VectorStoreDocument from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks @@ -21,7 +21,7 @@ async def embed_text( - input: pd.DataFrame, + input_table: Table, callbacks: WorkflowCallbacks, model: "LLMEmbedding", tokenizer: Tokenizer, @@ -31,59 +31,116 @@ async def embed_text( num_threads: int, vector_store: VectorStore, id_column: str = "id", -): - """Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector.""" - if embed_column not in input.columns: - msg = f"Column {embed_column} not found in input dataframe with columns {input.columns}" - raise ValueError(msg) - if id_column not in input.columns: - msg = f"Column {id_column} not found in input dataframe with columns {input.columns}" - raise ValueError(msg) - + output_table: Table | None = None, +) -> int: + """Embed text from a streaming Table into a vector store.""" vector_store.create_index() - index = 0 + buffer: list[dict[str, Any]] = [] + total_rows = 0 - all_results = [] + async for row in input_table: + text = row.get(embed_column) + if text is None: + text = "" - num_total_batches = (input.shape[0] + batch_size - 1) // batch_size - while batch_size * index < input.shape[0]: - logger.info( - "uploading text embeddings batch %d/%d of size %d to vector store", - index + 1, - num_total_batches, - batch_size, - ) - batch = input.iloc[batch_size * index : batch_size * (index + 1)] - texts: list[str] = batch[embed_column].tolist() - ids: list[str] = batch[id_column].tolist() - result = await run_embed_text( - texts, + buffer.append({ + id_column: row[id_column], + embed_column: text, + }) + + if len(buffer) >= batch_size: + total_rows += await _flush_embedding_buffer( + buffer, + embed_column, + id_column, + callbacks, + model, + tokenizer, + batch_size, + batch_max_tokens, + num_threads, + vector_store, + output_table, + ) + buffer.clear() + + if buffer: + total_rows += await _flush_embedding_buffer( + buffer, + embed_column, + id_column, callbacks, model, tokenizer, batch_size, batch_max_tokens, num_threads, + vector_store, + output_table, ) - if result.embeddings: - embeddings = [ - embedding for embedding in result.embeddings if embedding is not None - ] - all_results.extend(embeddings) - - vectors = result.embeddings or [] - documents: list[VectorStoreDocument] = [] - for doc_id, doc_vector in zip(ids, vectors, strict=True): - if type(doc_vector) is np.ndarray: - doc_vector = doc_vector.tolist() - document = VectorStoreDocument( + + return total_rows + + +async def _flush_embedding_buffer( + buffer: list[dict[str, Any]], + embed_column: str, + id_column: str, + callbacks: WorkflowCallbacks, + model: "LLMEmbedding", + tokenizer: Tokenizer, + batch_size: int, + batch_max_tokens: int, + num_threads: int, + vector_store: VectorStore, + output_table: Table | None, +) -> int: + """Embed a buffer of rows and load results into the vector store.""" + texts: list[str] = [row[embed_column] for row in buffer] + ids: list[str] = [row[id_column] for row in buffer] + + result = await run_embed_text( + texts, + callbacks, + model, + tokenizer, + batch_size, + batch_max_tokens, + num_threads, + ) + + vectors = result.embeddings or [] + skipped = 0 + documents: list[VectorStoreDocument] = [] + for doc_id, doc_vector in zip(ids, vectors, strict=True): + if doc_vector is None: + skipped += 1 + continue + if type(doc_vector) is np.ndarray: + doc_vector = doc_vector.tolist() + documents.append( + VectorStoreDocument( id=doc_id, vector=doc_vector, ) - documents.append(document) + ) + + vector_store.load_documents(documents) + + if skipped > 0: + logger.warning( + "Skipped %d rows with None embeddings out of %d", + skipped, + len(buffer), + ) - vector_store.load_documents(documents) - index += 1 + if output_table is not None: + for doc_id, doc_vector in zip(ids, vectors, strict=True): + if doc_vector is None: + continue + if type(doc_vector) is np.ndarray: + doc_vector = doc_vector.tolist() + await output_table.write({"id": doc_id, "embedding": doc_vector}) - return all_results + return len(buffer) diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index a78de2bb58..9f5965ef36 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -4,15 +4,13 @@ """A module containing run_workflow method definition.""" import logging +from contextlib import AsyncExitStack +from dataclasses import dataclass from typing import TYPE_CHECKING -import pandas as pd from graphrag_llm.embedding import create_embedding -from graphrag_llm.tokenizer import Tokenizer -from graphrag_vectors import ( - VectorStoreConfig, - create_vector_store, -) +from graphrag_storage.tables.table import RowTransformer +from graphrag_vectors import create_vector_store from graphrag.cache.cache_key_creator import cache_key_creator from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks @@ -22,160 +20,140 @@ text_unit_text_embedding, ) from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.data_model.data_reader import DataReader +from graphrag.data_model.row_transformers import ( + transform_entity_row_for_embedding, +) from graphrag.index.operations.embed_text.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput if TYPE_CHECKING: - from graphrag_llm.embedding import LLMEmbedding + from graphrag_cache import Cache + from graphrag_storage.tables.table_provider import TableProvider logger = logging.getLogger(__name__) +@dataclass +class EmbeddingFieldConfig: + """Configuration for a single embedding field. + + Describes which source table and column to embed, and an + optional row transform to apply before embedding. + """ + + name: str + table_name: str + embed_column: str + row_transform: RowTransformer | None = None + + +EMBEDDING_FIELDS: dict[str, EmbeddingFieldConfig] = { + text_unit_text_embedding: EmbeddingFieldConfig( + name=text_unit_text_embedding, + table_name="text_units", + embed_column="text", + ), + entity_description_embedding: EmbeddingFieldConfig( + name=entity_description_embedding, + table_name="entities", + embed_column="title_description", + row_transform=transform_entity_row_for_embedding, + ), + community_full_content_embedding: EmbeddingFieldConfig( + name=community_full_content_embedding, + table_name="community_reports", + embed_column="full_content", + ), +} + + async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, ) -> WorkflowFunctionOutput: - """All the steps to transform community reports.""" + """Generate text embeddings for configured fields via streaming Tables.""" logger.info("Workflow started: generate_text_embeddings") + + await generate_text_embeddings( + config=config, + table_provider=context.output_table_provider, + cache=context.cache, + callbacks=context.callbacks, + ) + + logger.info("Workflow completed: generate_text_embeddings") + return WorkflowFunctionOutput(result=None) + + +async def generate_text_embeddings( + config: GraphRagConfig, + table_provider: TableProvider, + cache: Cache, + callbacks: WorkflowCallbacks, +) -> None: + """Generate text embeddings for all configured fields.""" embedded_fields = config.embed_text.names logger.info("Embedding the following fields: %s", embedded_fields) - reader = DataReader(context.output_table_provider) - text_units = None - entities = None - community_reports = None - if text_unit_text_embedding in embedded_fields: - text_units = await reader.text_units() - if entity_description_embedding in embedded_fields: - entities = await reader.entities() - if community_full_content_embedding in embedded_fields: - community_reports = await reader.community_reports() model_config = config.get_embedding_model_config( config.embed_text.embedding_model_id ) - model = create_embedding( model_config, - cache=context.cache.child(config.embed_text.model_instance_name), + cache=cache.child(config.embed_text.model_instance_name), cache_key_creator=cache_key_creator, ) - tokenizer = model.tokenizer - output = await generate_text_embeddings( - text_units=text_units, - entities=entities, - community_reports=community_reports, - callbacks=context.callbacks, - model=model, - tokenizer=tokenizer, - batch_size=config.embed_text.batch_size, - batch_max_tokens=config.embed_text.batch_max_tokens, - num_threads=config.concurrent_requests, - vector_store_config=config.vector_store, - embedded_fields=embedded_fields, - ) + for field_name in embedded_fields: + field_config = EMBEDDING_FIELDS[field_name] - if config.snapshots.embeddings: - for name, table in output.items(): - await context.output_table_provider.write_dataframe( - f"embeddings.{name}", - table, + if not await table_provider.has(field_config.table_name): + logger.warning( + "Embedding %s is specified but source table '%s' " + "is not in storage. Skipping.", + field_config.name, + field_config.table_name, + ) + continue + + vector_store = create_vector_store( + config.vector_store, + config.vector_store.index_schema[field_config.name], + ) + vector_store.connect() + + async with AsyncExitStack() as stack: + input_table = await stack.enter_async_context( + table_provider.open( + field_config.table_name, + truncate=False, + transformer=field_config.row_transform, + ) ) - logger.info("Workflow completed: generate_text_embeddings") - return WorkflowFunctionOutput(result=output) - + output_table = None + if config.snapshots.embeddings: + output_table = await stack.enter_async_context( + table_provider.open(f"embeddings.{field_config.name}") + ) -async def generate_text_embeddings( - text_units: pd.DataFrame | None, - entities: pd.DataFrame | None, - community_reports: pd.DataFrame | None, - callbacks: WorkflowCallbacks, - model: "LLMEmbedding", - tokenizer: Tokenizer, - batch_size: int, - batch_max_tokens: int, - num_threads: int, - vector_store_config: VectorStoreConfig, - embedded_fields: list[str], -) -> dict[str, pd.DataFrame]: - """All the steps to generate all embeddings.""" - embedding_param_map = { - text_unit_text_embedding: { - "data": text_units.loc[:, ["id", "text"]].fillna("") - if text_units is not None - else None, - "embed_column": "text", - }, - entity_description_embedding: { - "data": entities - .loc[:, ["id", "title", "description"]] - .fillna("") - .assign(title_description=lambda df: df["title"] + ":" + df["description"]) - if entities is not None - else None, - "embed_column": "title_description", - }, - community_full_content_embedding: { - "data": community_reports.loc[:, ["id", "full_content"]].fillna("") - if community_reports is not None - else None, - "embed_column": "full_content", - }, - } - - logger.info("Creating embeddings") - outputs = {} - for field in embedded_fields: - if embedding_param_map[field]["data"] is None: - msg = f"Embedding {field} is specified but data table is not in storage. This may or may not be intentional - if you expect it to me here, please check for errors earlier in the logs." - logger.warning(msg) - else: - outputs[field] = await _run_embeddings( - name=field, + count = await embed_text( + input_table=input_table, callbacks=callbacks, model=model, tokenizer=tokenizer, - vector_store_config=vector_store_config, - batch_size=batch_size, - batch_max_tokens=batch_max_tokens, - num_threads=num_threads, - **embedding_param_map[field], + embed_column=field_config.embed_column, + batch_size=config.embed_text.batch_size, + batch_max_tokens=config.embed_text.batch_max_tokens, + num_threads=config.concurrent_requests, + vector_store=vector_store, + output_table=output_table, ) - return outputs - - -async def _run_embeddings( - name: str, - data: pd.DataFrame, - embed_column: str, - callbacks: WorkflowCallbacks, - model: "LLMEmbedding", - tokenizer: Tokenizer, - batch_size: int, - batch_max_tokens: int, - num_threads: int, - vector_store_config: VectorStoreConfig, -) -> pd.DataFrame: - """All the steps to generate single embedding.""" - vector_store = create_vector_store( - vector_store_config, vector_store_config.index_schema[name] - ) - vector_store.connect() - - data["embedding"] = await embed_text( - input=data, - callbacks=callbacks, - model=model, - tokenizer=tokenizer, - embed_column=embed_column, - batch_size=batch_size, - batch_max_tokens=batch_max_tokens, - num_threads=num_threads, - vector_store=vector_store, - ) - return data.loc[:, ["id", "embedding"]] + logger.info( + "Embedded %d rows for %s", + count, + field_config.name, + ) diff --git a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py index 4a3cf1a673..5ec12f60c1 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py @@ -1,18 +1,17 @@ -# Copyright (c) 2024 Microsoft Corporation. +# Copyright (C) 2026 Microsoft # Licensed under the MIT License """A module containing run_workflow method definition.""" import logging -from graphrag_llm.embedding import create_embedding - -from graphrag.cache.cache_key_creator import cache_key_creator from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings +from graphrag.index.workflows.generate_text_embeddings import ( + generate_text_embeddings, +) logger = logging.getLogger(__name__) @@ -21,48 +20,19 @@ async def run_workflow( config: GraphRagConfig, context: PipelineRunContext, ) -> WorkflowFunctionOutput: - """Update the text embeddings from a incremental index run.""" + """Update text embeddings for an incremental index run.""" logger.info("Workflow started: update_text_embeddings") + output_table_provider, _, _ = get_update_table_providers( config, context.state["update_timestamp"] ) - merged_text_units = context.state["incremental_update_merged_text_units"] - merged_entities_df = context.state["incremental_update_merged_entities"] - merged_community_reports = context.state[ - "incremental_update_merged_community_reports" - ] - - embedded_fields = config.embed_text.names - - model_config = config.get_embedding_model_config( - config.embed_text.embedding_model_id - ) - - model = create_embedding( - model_config, - cache=context.cache.child("text_embedding"), - cache_key_creator=cache_key_creator, - ) - - tokenizer = model.tokenizer - - result = await generate_text_embeddings( - text_units=merged_text_units, - entities=merged_entities_df, - community_reports=merged_community_reports, + await generate_text_embeddings( + config=config, + table_provider=output_table_provider, + cache=context.cache, callbacks=context.callbacks, - model=model, - tokenizer=tokenizer, - batch_size=config.embed_text.batch_size, - batch_max_tokens=config.embed_text.batch_max_tokens, - num_threads=config.concurrent_requests, - vector_store_config=config.vector_store, - embedded_fields=embedded_fields, ) - if config.snapshots.embeddings: - for name, table in result.items(): - await output_table_provider.write_dataframe(f"embeddings.{name}", table) logger.info("Workflow completed: update_text_embeddings") return WorkflowFunctionOutput(result=None) diff --git a/tests/unit/indexing/operations/__init__.py b/tests/unit/indexing/operations/__init__.py new file mode 100644 index 0000000000..05c6313ef3 --- /dev/null +++ b/tests/unit/indexing/operations/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2026 Microsoft +# Licensed under the MIT License diff --git a/tests/unit/indexing/operations/embed_text/__init__.py b/tests/unit/indexing/operations/embed_text/__init__.py new file mode 100644 index 0000000000..05c6313ef3 --- /dev/null +++ b/tests/unit/indexing/operations/embed_text/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2026 Microsoft +# Licensed under the MIT License diff --git a/tests/unit/indexing/operations/embed_text/test_embed_text.py b/tests/unit/indexing/operations/embed_text/test_embed_text.py new file mode 100644 index 0000000000..92ec0cfcf3 --- /dev/null +++ b/tests/unit/indexing/operations/embed_text/test_embed_text.py @@ -0,0 +1,406 @@ +# Copyright (C) 2026 Microsoft +# Licensed under the MIT License + +"""Unit tests for the streaming embed_text operation.""" + +from collections.abc import AsyncIterator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +from graphrag.callbacks.noop_workflow_callbacks import ( + NoopWorkflowCallbacks, +) +from graphrag.index.operations.embed_text.embed_text import embed_text +from graphrag.index.operations.embed_text.run_embed_text import ( + TextEmbeddingResult, +) +from graphrag_storage.tables.table import Table + + +class FakeInputTable(Table): + """In-memory table that yields rows via async iteration.""" + + def __init__(self, rows: list[dict[str, Any]]) -> None: + """Store the rows to be yielded.""" + self._rows = rows + + def __aiter__(self) -> AsyncIterator[dict[str, Any]]: + """Return an async iterator yielding each stored row.""" + return self._iter() + + async def _iter(self) -> AsyncIterator[dict[str, Any]]: + """Yield rows one at a time.""" + for row in self._rows: + yield dict(row) + + async def length(self) -> int: + """Return the number of rows.""" + return len(self._rows) + + async def has(self, row_id: str) -> bool: + """Check if a row with the given ID exists.""" + return any(r.get("id") == row_id for r in self._rows) + + async def write(self, row: dict[str, Any]) -> None: + """No-op write (input table is read-only).""" + + async def close(self) -> None: + """No-op close.""" + + +class FakeOutputTable(Table): + """Collects rows written via write() for assertion.""" + + def __init__(self) -> None: + """Initialize empty row collection.""" + self.rows: list[dict[str, Any]] = [] + + def __aiter__(self) -> AsyncIterator[dict[str, Any]]: + """Yield collected rows.""" + return self._iter() + + async def _iter(self) -> AsyncIterator[dict[str, Any]]: + """Yield rows one at a time.""" + for row in self.rows: + yield row + + async def length(self) -> int: + """Return the number of written rows.""" + return len(self.rows) + + async def has(self, row_id: str) -> bool: + """Check if a row with the given ID was written.""" + return any(r.get("id") == row_id for r in self.rows) + + async def write(self, row: dict[str, Any]) -> None: + """Append a row to the collection.""" + self.rows.append(row) + + async def close(self) -> None: + """No-op close.""" + + +def _make_mock_vector_store(): + """Create a mock vector store with create_index and load_documents.""" + store = MagicMock() + store.create_index = MagicMock() + store.load_documents = MagicMock() + return store + + +def _make_mock_model(embedding_values: list[float]): + """Create a mock model that returns fixed embeddings.""" + model = MagicMock() + model.tokenizer = MagicMock() + return model, embedding_values + + +def _make_embedding_result(count: int, values: list[float]) -> TextEmbeddingResult: + """Build a TextEmbeddingResult with count copies of values.""" + return TextEmbeddingResult(embeddings=[list(values) for _ in range(count)]) + + +@pytest.mark.asyncio +async def test_embed_text_basic(): + """Verify basic embedding: rows flow through to vector store and output table.""" + rows = [ + {"id": "a", "text": "hello world"}, + {"id": "b", "text": "foo bar"}, + {"id": "c", "text": "baz qux"}, + ] + input_table = FakeInputTable(rows) + output_table = FakeOutputTable() + vector_store = _make_mock_vector_store() + embedding_values = [1.0, 2.0, 3.0] + + with patch( + "graphrag.index.operations.embed_text.embed_text.run_embed_text", + new_callable=AsyncMock, + ) as mock_run: + mock_run.return_value = _make_embedding_result(3, embedding_values) + + count = await embed_text( + input_table=input_table, + callbacks=NoopWorkflowCallbacks(), + model=MagicMock(), + tokenizer=MagicMock(), + embed_column="text", + batch_size=10, + batch_max_tokens=8191, + num_threads=1, + vector_store=vector_store, + output_table=output_table, + ) + + assert count == 3 + assert len(output_table.rows) == 3 + assert output_table.rows[0]["id"] == "a" + assert output_table.rows[0]["embedding"] == embedding_values + assert output_table.rows[2]["id"] == "c" + + vector_store.create_index.assert_called_once() + vector_store.load_documents.assert_called_once() + docs = vector_store.load_documents.call_args[0][0] + assert len(docs) == 3 + assert docs[0].id == "a" + assert docs[1].id == "b" + + +@pytest.mark.asyncio +async def test_embed_text_batching(): + """Verify rows are flushed in batches when batch_size < total rows.""" + rows = [{"id": str(i), "text": f"text {i}"} for i in range(5)] + input_table = FakeInputTable(rows) + vector_store = _make_mock_vector_store() + + with patch( + "graphrag.index.operations.embed_text.embed_text.run_embed_text", + new_callable=AsyncMock, + ) as mock_run: + mock_run.side_effect = [ + _make_embedding_result(2, [1.0]), + _make_embedding_result(2, [2.0]), + _make_embedding_result(1, [3.0]), + ] + + count = await embed_text( + input_table=input_table, + callbacks=NoopWorkflowCallbacks(), + model=MagicMock(), + tokenizer=MagicMock(), + embed_column="text", + batch_size=2, + batch_max_tokens=8191, + num_threads=1, + vector_store=vector_store, + ) + + assert count == 5 + assert mock_run.call_count == 3 + assert vector_store.load_documents.call_count == 3 + + +@pytest.mark.asyncio +async def test_embed_text_pretransformed_rows(): + """Verify rows pre-transformed by table layer are embedded correctly.""" + rows = [ + { + "id": "1", + "title": "Alpha", + "description": "First", + "combined": "Alpha:First", + }, + { + "id": "2", + "title": "Beta", + "description": "Second", + "combined": "Beta:Second", + }, + ] + input_table = FakeInputTable(rows) + output_table = FakeOutputTable() + vector_store = _make_mock_vector_store() + + with patch( + "graphrag.index.operations.embed_text.embed_text.run_embed_text", + new_callable=AsyncMock, + ) as mock_run: + mock_run.return_value = _make_embedding_result(2, [0.5]) + + count = await embed_text( + input_table=input_table, + callbacks=NoopWorkflowCallbacks(), + model=MagicMock(), + tokenizer=MagicMock(), + embed_column="combined", + batch_size=10, + batch_max_tokens=8191, + num_threads=1, + vector_store=vector_store, + output_table=output_table, + ) + + assert count == 2 + texts_arg = mock_run.call_args[0][0] + assert texts_arg == ["Alpha:First", "Beta:Second"] + + +@pytest.mark.asyncio +async def test_embed_text_none_values_filled(): + """Verify None embed_column values are replaced with empty string.""" + rows = [ + {"id": "1", "text": None}, + {"id": "2", "text": "real text"}, + ] + input_table = FakeInputTable(rows) + vector_store = _make_mock_vector_store() + + with patch( + "graphrag.index.operations.embed_text.embed_text.run_embed_text", + new_callable=AsyncMock, + ) as mock_run: + mock_run.return_value = _make_embedding_result(2, [1.0]) + + count = await embed_text( + input_table=input_table, + callbacks=NoopWorkflowCallbacks(), + model=MagicMock(), + tokenizer=MagicMock(), + embed_column="text", + batch_size=10, + batch_max_tokens=8191, + num_threads=1, + vector_store=vector_store, + ) + + assert count == 2 + texts_arg = mock_run.call_args[0][0] + assert texts_arg == ["", "real text"] + + +@pytest.mark.asyncio +async def test_embed_text_no_output_table(): + """Verify embedding works without an output table (no snapshot).""" + rows = [{"id": "x", "text": "data"}] + input_table = FakeInputTable(rows) + vector_store = _make_mock_vector_store() + + with patch( + "graphrag.index.operations.embed_text.embed_text.run_embed_text", + new_callable=AsyncMock, + ) as mock_run: + mock_run.return_value = _make_embedding_result(1, [9.0]) + + count = await embed_text( + input_table=input_table, + callbacks=NoopWorkflowCallbacks(), + model=MagicMock(), + tokenizer=MagicMock(), + embed_column="text", + batch_size=10, + batch_max_tokens=8191, + num_threads=1, + vector_store=vector_store, + output_table=None, + ) + + assert count == 1 + vector_store.load_documents.assert_called_once() + + +@pytest.mark.asyncio +async def test_embed_text_empty_input(): + """Verify zero rows returns zero count with no calls.""" + input_table = FakeInputTable([]) + vector_store = _make_mock_vector_store() + + with patch( + "graphrag.index.operations.embed_text.embed_text.run_embed_text", + new_callable=AsyncMock, + ) as mock_run: + count = await embed_text( + input_table=input_table, + callbacks=NoopWorkflowCallbacks(), + model=MagicMock(), + tokenizer=MagicMock(), + embed_column="text", + batch_size=10, + batch_max_tokens=8191, + num_threads=1, + vector_store=vector_store, + ) + + assert count == 0 + mock_run.assert_not_called() + vector_store.load_documents.assert_not_called() + + +@pytest.mark.asyncio +async def test_embed_text_numpy_array_vectors(): + """Verify np.ndarray embeddings are converted to plain lists.""" + rows = [ + {"id": "a", "text": "hello"}, + {"id": "b", "text": "world"}, + ] + input_table = FakeInputTable(rows) + output_table = FakeOutputTable() + vector_store = _make_mock_vector_store() + + numpy_embeddings = [np.array([1.0, 2.0]), np.array([3.0, 4.0])] + + with patch( + "graphrag.index.operations.embed_text.embed_text.run_embed_text", + new_callable=AsyncMock, + ) as mock_run: + mock_run.return_value = TextEmbeddingResult(embeddings=numpy_embeddings) + + count = await embed_text( + input_table=input_table, + callbacks=NoopWorkflowCallbacks(), + model=MagicMock(), + tokenizer=MagicMock(), + embed_column="text", + batch_size=10, + batch_max_tokens=8191, + num_threads=1, + vector_store=vector_store, + output_table=output_table, + ) + + assert count == 2 + + docs = vector_store.load_documents.call_args[0][0] + assert docs[0].vector == [1.0, 2.0] + assert docs[1].vector == [3.0, 4.0] + assert type(docs[0].vector) is list + assert type(docs[1].vector) is list + + assert output_table.rows[0]["embedding"] == [1.0, 2.0] + assert type(output_table.rows[0]["embedding"]) is list + + +@pytest.mark.asyncio +async def test_embed_text_partial_none_embeddings(): + """Verify rows with None embeddings are skipped in store and output.""" + rows = [ + {"id": "a", "text": "good"}, + {"id": "b", "text": "failed"}, + {"id": "c", "text": "also good"}, + ] + input_table = FakeInputTable(rows) + output_table = FakeOutputTable() + vector_store = _make_mock_vector_store() + + mixed_embeddings = [[1.0, 2.0], None, [3.0, 4.0]] + + with patch( + "graphrag.index.operations.embed_text.embed_text.run_embed_text", + new_callable=AsyncMock, + ) as mock_run: + mock_run.return_value = TextEmbeddingResult(embeddings=mixed_embeddings) + + count = await embed_text( + input_table=input_table, + callbacks=NoopWorkflowCallbacks(), + model=MagicMock(), + tokenizer=MagicMock(), + embed_column="text", + batch_size=10, + batch_max_tokens=8191, + num_threads=1, + vector_store=vector_store, + output_table=output_table, + ) + + assert count == 3 + + docs = vector_store.load_documents.call_args[0][0] + assert len(docs) == 2 + assert docs[0].id == "a" + assert docs[1].id == "c" + + assert len(output_table.rows) == 2 + assert output_table.rows[0]["id"] == "a" + assert output_table.rows[1]["id"] == "c" diff --git a/tests/verbs/test_update_text_embeddings.py b/tests/verbs/test_update_text_embeddings.py new file mode 100644 index 0000000000..f0d5e4adf0 --- /dev/null +++ b/tests/verbs/test_update_text_embeddings.py @@ -0,0 +1,65 @@ +# Copyright (C) 2026 Microsoft +# Licensed under the MIT License + +"""Verb test for the update_text_embeddings workflow.""" + +from unittest.mock import patch + +from graphrag.config.embeddings import all_embeddings +from graphrag.index.workflows.update_text_embeddings import ( + run_workflow, +) + +from tests.unit.config.utils import get_default_graphrag_config + +from .util import create_test_context + + +async def test_update_text_embeddings(): + """Verify update_text_embeddings produces embedding tables. + + Mocks get_update_table_providers to return the test context's + output_table_provider, simulating the merged tables written by + upstream update workflows. + """ + context = await create_test_context( + storage=[ + "documents", + "relationships", + "text_units", + "entities", + "community_reports", + ] + ) + context.state["update_timestamp"] = "20260220-000000" + + config = get_default_graphrag_config() + llm_settings = config.get_embedding_model_config( + config.embed_text.embedding_model_id + ) + llm_settings.type = "mock" + llm_settings.mock_responses = [1.0] * 3072 + + config.embed_text.names = list(all_embeddings) + config.snapshots.embeddings = True + + with patch( + "graphrag.index.workflows.update_text_embeddings.get_update_table_providers", + ) as mock_providers: + mock_providers.return_value = ( + context.output_table_provider, + None, + None, + ) + await run_workflow(config, context) + + parquet_files = context.output_storage.keys() + for field in all_embeddings: + assert f"embeddings.{field}.parquet" in parquet_files + + entity_embeddings = await context.output_table_provider.read_dataframe( + "embeddings.entity_description" + ) + assert len(entity_embeddings.columns) == 2 + assert "id" in entity_embeddings.columns + assert "embedding" in entity_embeddings.columns From 0537fadadb1b6ef8a7c3ec2d09ab326c61dbfc62 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Fri, 20 Feb 2026 19:04:33 -0300 Subject: [PATCH 2/2] fixes --- .../index/workflows/generate_text_embeddings.py | 4 ++-- .../indexing/operations/embed_text/test_embed_text.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index 9f5965ef36..c9deead1cb 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -88,8 +88,8 @@ async def run_workflow( async def generate_text_embeddings( config: GraphRagConfig, - table_provider: TableProvider, - cache: Cache, + table_provider: "TableProvider", + cache: "Cache", callbacks: WorkflowCallbacks, ) -> None: """Generate text embeddings for all configured fields.""" diff --git a/tests/unit/indexing/operations/embed_text/test_embed_text.py b/tests/unit/indexing/operations/embed_text/test_embed_text.py index 92ec0cfcf3..9a519a6930 100644 --- a/tests/unit/indexing/operations/embed_text/test_embed_text.py +++ b/tests/unit/indexing/operations/embed_text/test_embed_text.py @@ -328,13 +328,20 @@ async def test_embed_text_numpy_array_vectors(): output_table = FakeOutputTable() vector_store = _make_mock_vector_store() - numpy_embeddings = [np.array([1.0, 2.0]), np.array([3.0, 4.0])] + numpy_embeddings: list[list[float] | None] = [ + np.array([1.0, 2.0]).tolist(), + np.array([3.0, 4.0]).tolist(), + ] with patch( "graphrag.index.operations.embed_text.embed_text.run_embed_text", new_callable=AsyncMock, ) as mock_run: - mock_run.return_value = TextEmbeddingResult(embeddings=numpy_embeddings) + # Simulate run_embed_text returning np.ndarray objects at runtime + # by replacing the result embeddings after construction. + result = TextEmbeddingResult(embeddings=numpy_embeddings) + result.embeddings = [np.array([1.0, 2.0]), np.array([3.0, 4.0])] # type: ignore[list-item] + mock_run.return_value = result count = await embed_text( input_table=input_table,