diff --git a/docs/config/yaml.md b/docs/config/yaml.md index ace57e3b1c..76c177b60f 100644 --- a/docs/config/yaml.md +++ b/docs/config/yaml.md @@ -182,14 +182,9 @@ Where to put all vectors for the system. Configured for lancedb by default. This The supported embeddings are: -- `text_unit.text` -- `document.text` -- `entity.title` -- `entity.description` -- `relationship.description` -- `community.title` -- `community.summary` -- `community.full_content` +- `text_unit_text` +- `entity_description` +- `community_full_content` For example: @@ -199,12 +194,12 @@ vector_store: db_uri: output/lancedb index_prefix: "christmas-carol" embeddings_schema: - text_unit.text: + text_unit_text: index_name: "text-unit-embeddings" id_field: "id_custom" vector_field: "vector_custom" vector_size: 3072 - entity.description: + entity_description: id_field: "id_custom" ``` @@ -224,14 +219,9 @@ By default, the GraphRAG indexer will only export embeddings required for our qu Supported embeddings names are: -- `text_unit.text` -- `document.text` -- `entity.title` -- `entity.description` -- `relationship.description` -- `community.title` -- `community.summary` -- `community.full_content` +- `text_unit_text` +- `entity_description` +- `community_full_content` #### Fields diff --git a/docs/examples_notebooks/api_overview.ipynb b/docs/examples_notebooks/api_overview.ipynb index 06187a5771..abcd7832fc 100644 --- a/docs/examples_notebooks/api_overview.ipynb +++ b/docs/examples_notebooks/api_overview.ipynb @@ -28,10 +28,11 @@ "from pathlib import Path\n", "from pprint import pprint\n", "\n", - "import graphrag.api as api\n", "import pandas as pd\n", "from graphrag.config.load_config import load_config\n", - "from graphrag.index.typing.pipeline_run_result import PipelineRunResult" + "from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n", + "\n", + "import graphrag.api as api" ] }, { @@ -170,7 +171,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "graphrag-monorepo", "language": "python", "name": "python3" }, @@ -184,7 +185,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/docs/examples_notebooks/index_migration_to_v1.ipynb b/docs/examples_notebooks/index_migration_to_v1.ipynb index 4a89d95305..68a06358e6 100644 --- a/docs/examples_notebooks/index_migration_to_v1.ipynb +++ b/docs/examples_notebooks/index_migration_to_v1.ipynb @@ -229,8 +229,6 @@ "tokenizer = get_tokenizer(model_config)\n", "\n", "await generate_text_embeddings(\n", - " documents=None,\n", - " relationships=None,\n", " text_units=final_text_units,\n", " entities=final_entities,\n", " community_reports=final_community_reports,\n", diff --git a/docs/examples_notebooks/input_documents.ipynb b/docs/examples_notebooks/input_documents.ipynb index b9af6075ab..505c0fe1f3 100644 --- a/docs/examples_notebooks/input_documents.ipynb +++ b/docs/examples_notebooks/input_documents.ipynb @@ -30,10 +30,11 @@ "from pathlib import Path\n", "from pprint import pprint\n", "\n", - "import graphrag.api as api\n", "import pandas as pd\n", "from graphrag.config.load_config import load_config\n", - "from graphrag.index.typing.pipeline_run_result import PipelineRunResult" + "from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n", + "\n", + "import graphrag.api as api" ] }, { @@ -171,7 +172,7 @@ ], "metadata": { "kernelspec": { - "display_name": "graphrag", + "display_name": "graphrag-monorepo", "language": "python", "name": "python3" }, @@ -185,7 +186,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.10" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/packages/graphrag/graphrag/config/embeddings.py b/packages/graphrag/graphrag/config/embeddings.py index 60711b8aa6..c06567acfa 100644 --- a/packages/graphrag/graphrag/config/embeddings.py +++ b/packages/graphrag/graphrag/config/embeddings.py @@ -3,22 +3,12 @@ """A module containing embeddings values.""" -entity_title_embedding = "entity.title" -entity_description_embedding = "entity.description" -relationship_description_embedding = "relationship.description" -document_text_embedding = "document.text" -community_title_embedding = "community.title" -community_summary_embedding = "community.summary" -community_full_content_embedding = "community.full_content" -text_unit_text_embedding = "text_unit.text" +entity_description_embedding = "entity_description" +community_full_content_embedding = "community_full_content" +text_unit_text_embedding = "text_unit_text" all_embeddings: set[str] = { - entity_title_embedding, entity_description_embedding, - relationship_description_embedding, - document_text_embedding, - community_title_embedding, - community_summary_embedding, community_full_content_embedding, text_unit_text_embedding, } @@ -47,5 +37,5 @@ def create_index_name( raise KeyError(msg) if index_prefix: - return f"{index_prefix}-{embedding_name}".replace(".", "-") - return embedding_name.replace(".", "-") + return f"{index_prefix}-{embedding_name}" + return embedding_name diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index 6e474437f5..ef24d3e348 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -10,13 +10,8 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.embeddings import ( community_full_content_embedding, - community_summary_embedding, - community_title_embedding, create_index_name, - document_text_embedding, entity_description_embedding, - entity_title_embedding, - relationship_description_embedding, text_unit_text_embedding, ) from graphrag.config.models.graph_rag_config import GraphRagConfig @@ -47,29 +42,14 @@ async def run_workflow( logger.info("Workflow started: generate_text_embeddings") embedded_fields = config.embed_text.names logger.info("Embedding the following fields: %s", embedded_fields) - documents = None - relationships = None text_units = None entities = None community_reports = None - if document_text_embedding in embedded_fields: - documents = await load_table_from_storage("documents", context.output_storage) - if relationship_description_embedding in embedded_fields: - relationships = await load_table_from_storage( - "relationships", context.output_storage - ) if text_unit_text_embedding in embedded_fields: text_units = await load_table_from_storage("text_units", context.output_storage) - if ( - entity_title_embedding in embedded_fields - or entity_description_embedding in embedded_fields - ): + if entity_description_embedding in embedded_fields: entities = await load_table_from_storage("entities", context.output_storage) - if ( - community_title_embedding in embedded_fields - or community_summary_embedding in embedded_fields - or community_full_content_embedding in embedded_fields - ): + if community_full_content_embedding in embedded_fields: community_reports = await load_table_from_storage( "community_reports", context.output_storage ) @@ -87,8 +67,6 @@ async def run_workflow( tokenizer = get_tokenizer(model_config) output = await generate_text_embeddings( - documents=documents, - relationships=relationships, text_units=text_units, entities=entities, community_reports=community_reports, @@ -115,8 +93,6 @@ async def run_workflow( async def generate_text_embeddings( - documents: pd.DataFrame | None, - relationships: pd.DataFrame | None, text_units: pd.DataFrame | None, entities: pd.DataFrame | None, community_reports: pd.DataFrame | None, @@ -131,26 +107,12 @@ async def generate_text_embeddings( ) -> dict[str, pd.DataFrame]: """All the steps to generate all embeddings.""" embedding_param_map = { - document_text_embedding: { - "data": documents.loc[:, ["id", "text"]] if documents is not None else None, - "embed_column": "text", - }, - relationship_description_embedding: { - "data": relationships.loc[:, ["id", "description"]] - if relationships is not None - else None, - "embed_column": "description", - }, text_unit_text_embedding: { "data": text_units.loc[:, ["id", "text"]] if text_units is not None else None, "embed_column": "text", }, - entity_title_embedding: { - "data": entities.loc[:, ["id", "title"]] if entities is not None else None, - "embed_column": "title", - }, entity_description_embedding: { "data": entities.loc[:, ["id", "title", "description"]].assign( title_description=lambda df: df["title"] + ":" + df["description"] @@ -159,18 +121,6 @@ async def generate_text_embeddings( else None, "embed_column": "title_description", }, - community_title_embedding: { - "data": community_reports.loc[:, ["id", "title"]] - if community_reports is not None - else None, - "embed_column": "title", - }, - community_summary_embedding: { - "data": community_reports.loc[:, ["id", "summary"]] - if community_reports is not None - else None, - "embed_column": "summary", - }, community_full_content_embedding: { "data": community_reports.loc[:, ["id", "full_content"]] if community_reports is not None diff --git a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py index 0fc5f1fdeb..4c6c280650 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py @@ -26,9 +26,6 @@ async def run_workflow( output_storage, _, _ = get_update_storages( config, context.state["update_timestamp"] ) - - final_documents_df = context.state["incremental_update_final_documents"] - merged_relationships_df = context.state["incremental_update_merged_relationships"] merged_text_units = context.state["incremental_update_merged_text_units"] merged_entities_df = context.state["incremental_update_merged_entities"] merged_community_reports = context.state[ @@ -50,8 +47,6 @@ async def run_workflow( tokenizer = get_tokenizer(model_config) result = await generate_text_embeddings( - documents=final_documents_df, - relationships=merged_relationships_df, text_units=merged_text_units, entities=merged_entities_df, community_reports=merged_community_reports, diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index ca2bd7d823..c28508be8b 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -85,9 +85,9 @@ ], "max_runtime": 150, "expected_artifacts": [ - "embeddings.text_unit.text.parquet", - "embeddings.entity.description.parquet", - "embeddings.community.full_content.parquet" + "embeddings.text_unit_text.parquet", + "embeddings.entity_description.parquet", + "embeddings.community_full_content.parquet" ] } }, diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index eef957da2f..792b91c48c 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -84,9 +84,9 @@ ], "max_runtime": 150, "expected_artifacts": [ - "embeddings.text_unit.text.parquet", - "embeddings.entity.description.parquet", - "embeddings.community.full_content.parquet" + "embeddings.text_unit_text.parquet", + "embeddings.entity_description.parquet", + "embeddings.community_full_content.parquet" ] } }, diff --git a/tests/unit/utils/test_embeddings.py b/tests/unit/utils/test_embeddings.py index 343ee812b5..63d0619c1d 100644 --- a/tests/unit/utils/test_embeddings.py +++ b/tests/unit/utils/test_embeddings.py @@ -6,8 +6,8 @@ def test_create_index_name(): - collection = create_index_name("default", "entity.title") - assert collection == "default-entity-title" + collection = create_index_name("default", "entity_description") + assert collection == "default-entity_description" def test_create_index_name_invalid_embedding_throws(): @@ -16,5 +16,5 @@ def test_create_index_name_invalid_embedding_throws(): def test_create_index_name_invalid_embedding_does_not_throw(): - collection = create_index_name("default", "invalid.name", validate=False) - assert collection == "default-invalid-name" + collection = create_index_name("default", "invalid_name", validate=False) + assert collection == "default-invalid_name" diff --git a/tests/verbs/test_create_community_reports.py b/tests/verbs/test_create_community_reports.py index d479120ce2..561f54108b 100644 --- a/tests/verbs/test_create_community_reports.py +++ b/tests/verbs/test_create_community_reports.py @@ -4,15 +4,16 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS -from graphrag.index.operations.summarize_communities.community_reports_extractor import ( - CommunityReportResponse, - FindingModel, -) from graphrag.index.workflows.create_community_reports import ( run_workflow, ) from graphrag.utils.storage import load_table_from_storage +from graphrag.index.operations.summarize_communities.community_reports_extractor import ( + CommunityReportResponse, + FindingModel, +) + from .util import ( DEFAULT_MODEL_CONFIG, compare_outputs, diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index e9e6c00afa..f65e0c642c 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -44,18 +44,9 @@ async def test_generate_text_embeddings(): # entity description should always be here, let's assert its format entity_description_embeddings = await load_table_from_storage( - "embeddings.entity.description", context.output_storage + "embeddings.entity_description", context.output_storage ) assert len(entity_description_embeddings.columns) == 2 assert "id" in entity_description_embeddings.columns assert "embedding" in entity_description_embeddings.columns - - # every other embedding is optional but we've turned them all on, so check a random one - document_text_embeddings = await load_table_from_storage( - "embeddings.document.text", context.output_storage - ) - - assert len(document_text_embeddings.columns) == 2 - assert "id" in document_text_embeddings.columns - assert "embedding" in document_text_embeddings.columns diff --git a/unified-search-app/app/app_logic.py b/unified-search-app/app/app_logic.py index dc64e0e77c..a573b9daa5 100644 --- a/unified-search-app/app/app_logic.py +++ b/unified-search-app/app/app_logic.py @@ -7,7 +7,6 @@ import logging from typing import TYPE_CHECKING -import graphrag.api as api import streamlit as st from knowledge_loader.data_sources.loader import ( create_datasource, @@ -18,6 +17,8 @@ from state.session_variables import SessionVariables from ui.search import display_search_result +import graphrag.api as api + if TYPE_CHECKING: import pandas as pd