diff --git a/packages/graphrag/graphrag/api/index.py b/packages/graphrag/graphrag/api/index.py index 7265e46187..c15ea2e786 100644 --- a/packages/graphrag/graphrag/api/index.py +++ b/packages/graphrag/graphrag/api/index.py @@ -86,8 +86,9 @@ async def build_index( input_documents=input_documents, ): outputs.append(output) - if output.errors and len(output.errors) > 0: + if output.error is not None: logger.error("Workflow %s completed with errors", output.workflow) + workflow_callbacks.pipeline_error(output.error) else: logger.info("Workflow %s completed successfully", output.workflow) logger.debug(str(output.result)) diff --git a/packages/graphrag/graphrag/callbacks/console_workflow_callbacks.py b/packages/graphrag/graphrag/callbacks/console_workflow_callbacks.py index dbe7e0d552..547c5ed258 100644 --- a/packages/graphrag/graphrag/callbacks/console_workflow_callbacks.py +++ b/packages/graphrag/graphrag/callbacks/console_workflow_callbacks.py @@ -37,6 +37,10 @@ def workflow_end(self, name: str, instance: object) -> None: if self._verbose: print(instance) + def pipeline_error(self, error: BaseException) -> None: + """Execute this callback when an error occurs in the pipeline.""" + print(f"Pipeline error: {error}") + def progress(self, progress: Progress) -> None: """Handle when progress occurs.""" complete = progress.completed_items or 0 diff --git a/packages/graphrag/graphrag/callbacks/noop_workflow_callbacks.py b/packages/graphrag/graphrag/callbacks/noop_workflow_callbacks.py index 9f9ac2aee0..19aba39a0c 100644 --- a/packages/graphrag/graphrag/callbacks/noop_workflow_callbacks.py +++ b/packages/graphrag/graphrag/callbacks/noop_workflow_callbacks.py @@ -25,3 +25,6 @@ def workflow_end(self, name: str, instance: object) -> None: def progress(self, progress: Progress) -> None: """Handle when progress occurs.""" + + def pipeline_error(self, error: BaseException) -> None: + """Execute this callback when an error occurs in the pipeline.""" diff --git a/packages/graphrag/graphrag/callbacks/workflow_callbacks.py b/packages/graphrag/graphrag/callbacks/workflow_callbacks.py index 0429cff809..3fb09710f9 100644 --- a/packages/graphrag/graphrag/callbacks/workflow_callbacks.py +++ b/packages/graphrag/graphrag/callbacks/workflow_callbacks.py @@ -35,3 +35,7 @@ def workflow_end(self, name: str, instance: object) -> None: def progress(self, progress: Progress) -> None: """Handle when progress occurs.""" ... + + def pipeline_error(self, error: BaseException) -> None: + """Execute this callback when an error occurs in the pipeline.""" + ... diff --git a/packages/graphrag/graphrag/callbacks/workflow_callbacks_manager.py b/packages/graphrag/graphrag/callbacks/workflow_callbacks_manager.py index 1ca0c097e5..6a030ec66a 100644 --- a/packages/graphrag/graphrag/callbacks/workflow_callbacks_manager.py +++ b/packages/graphrag/graphrag/callbacks/workflow_callbacks_manager.py @@ -50,3 +50,9 @@ def progress(self, progress: Progress) -> None: for callback in self._callbacks: if hasattr(callback, "progress"): callback.progress(progress) + + def pipeline_error(self, error: BaseException) -> None: + """Execute this callback when an error occurs in the pipeline.""" + for callback in self._callbacks: + if hasattr(callback, "pipeline_error"): + callback.pipeline_error(error) diff --git a/packages/graphrag/graphrag/cli/index.py b/packages/graphrag/graphrag/cli/index.py index 0a638f63e9..d686a38f13 100644 --- a/packages/graphrag/graphrag/cli/index.py +++ b/packages/graphrag/graphrag/cli/index.py @@ -134,15 +134,6 @@ def _run_index( verbose=verbose, ) ) - encountered_errors = any( - output.errors and len(output.errors) > 0 for output in outputs - ) - - if encountered_errors: - logger.error( - "Errors occurred during the pipeline run, see logs for more details." - ) - else: - logger.info("All workflows completed successfully.") + encountered_errors = any(output.error is not None for output in outputs) sys.exit(1 if encountered_errors else 0) diff --git a/packages/graphrag/graphrag/index/operations/extract_graph/extract_graph.py b/packages/graphrag/graphrag/index/operations/extract_graph/extract_graph.py index 3aa87404ec..0d2ab1146e 100644 --- a/packages/graphrag/graphrag/index/operations/extract_graph/extract_graph.py +++ b/packages/graphrag/graphrag/index/operations/extract_graph/extract_graph.py @@ -5,17 +5,11 @@ import logging -import networkx as nx import pandas as pd from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor -from graphrag.index.operations.extract_graph.typing import ( - Document, - EntityExtractionResult, - EntityTypes, -) from graphrag.index.utils.derive_from_rows import derive_from_rows from graphrag.language_model.protocol.base import ChatModel @@ -42,14 +36,15 @@ async def run_strategy(row): text = row[text_column] id = row[id_column] result = await run_extract_graph( - [Document(text=text, id=id)], - entity_types, - model, - prompt, - max_gleanings, + text=text, + source_id=id, + entity_types=entity_types, + model=model, + prompt=prompt, + max_gleanings=max_gleanings, ) num_started += 1 - return [result.entities, result.relationships, result.graph] + return result results = await derive_from_rows( text_units, @@ -64,8 +59,8 @@ async def run_strategy(row): relationship_dfs = [] for result in results: if result: - entity_dfs.append(pd.DataFrame(result[0])) - relationship_dfs.append(pd.DataFrame(result[1])) + entity_dfs.append(result[0]) + relationship_dfs.append(result[1]) entities = _merge_entities(entity_dfs) relationships = _merge_relationships(relationship_dfs) @@ -74,12 +69,13 @@ async def run_strategy(row): async def run_extract_graph( - docs: list[Document], - entity_types: EntityTypes, + text: str, + source_id: str, + entity_types: list[str], model: ChatModel, prompt: str, max_gleanings: int, -) -> EntityExtractionResult: +) -> tuple[pd.DataFrame, pd.DataFrame]: """Run the graph intelligence entity extraction strategy.""" extractor = GraphExtractor( model=model, @@ -89,36 +85,15 @@ async def run_extract_graph( "Entity Extraction Error", exc_info=e, extra={"stack": s, "details": d} ), ) - text_list = [doc.text.strip() for doc in docs] + text = text.strip() - results = await extractor( - list(text_list), + entities_df, relationships_df = await extractor( + text, entity_types=entity_types, + source_id=source_id, ) - graph = results.output - # Map the "source_id" back to the "id" field - for _, node in graph.nodes(data=True): # type: ignore - if node is not None: - node["source_id"] = ",".join( - docs[int(id)].id for id in node["source_id"].split(",") - ) - - for _, _, edge in graph.edges(data=True): # type: ignore - if edge is not None: - edge["source_id"] = ",".join( - docs[int(id)].id for id in edge["source_id"].split(",") - ) - - entities = [ - ({"title": item[0], **(item[1] or {})}) - for item in graph.nodes(data=True) - if item is not None - ] - - relationships = nx.to_pandas_edgelist(graph) - - return EntityExtractionResult(entities, relationships, graph) + return (entities_df, relationships_df) def _merge_entities(entity_dfs) -> pd.DataFrame: diff --git a/packages/graphrag/graphrag/index/operations/extract_graph/graph_extractor.py b/packages/graphrag/graphrag/index/operations/extract_graph/graph_extractor.py index 8e98eb7ec4..6d37bf4688 100644 --- a/packages/graphrag/graphrag/index/operations/extract_graph/graph_extractor.py +++ b/packages/graphrag/graphrag/index/operations/extract_graph/graph_extractor.py @@ -1,16 +1,14 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""A module containing 'GraphExtractionResult' and 'GraphExtractor' models.""" +"""Graph extraction helpers that return tabular data.""" import logging import re import traceback -from collections.abc import Mapping -from dataclasses import dataclass from typing import Any -import networkx as nx +import pandas as pd from graphrag.index.typing.error_handler import ErrorHandlerFn from graphrag.index.utils.string import clean_str @@ -27,24 +25,14 @@ TUPLE_DELIMITER = "<|>" RECORD_DELIMITER = "##" COMPLETION_DELIMITER = "<|COMPLETE|>" -DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] logger = logging.getLogger(__name__) -@dataclass -class GraphExtractionResult: - """Unipartite graph extraction result class definition.""" - - output: nx.Graph - source_docs: dict[Any, Any] - - class GraphExtractor: """Unipartite graph extractor class definition.""" _model: ChatModel - _join_descriptions: bool _extraction_prompt: str _max_gleanings: int _on_error: ErrorHandlerFn @@ -54,51 +42,40 @@ def __init__( model: ChatModel, prompt: str, max_gleanings: int, - join_descriptions=True, on_error: ErrorHandlerFn | None = None, ): """Init method definition.""" self._model = model - self._join_descriptions = join_descriptions self._extraction_prompt = prompt self._max_gleanings = max_gleanings self._on_error = on_error or (lambda _e, _s, _d: None) async def __call__( - self, texts: list[str], entity_types: list[str] - ) -> GraphExtractionResult: - """Call method definition.""" - all_records: dict[int, str] = {} - source_doc_map: dict[int, str] = {} - - for doc_index, text in enumerate(texts): - try: - # Invoke the entity extraction - result = await self._process_document(text, entity_types) - source_doc_map[doc_index] = text - all_records[doc_index] = result - except Exception as e: - logger.exception("error extracting graph") - self._on_error( - e, - traceback.format_exc(), - { - "doc_index": doc_index, - "text": text, - }, - ) - - output = await self._process_results( - all_records, + self, text: str, entity_types: list[str], source_id: str + ) -> tuple[pd.DataFrame, pd.DataFrame]: + """Extract entities and relationships from the supplied text.""" + try: + # Invoke the entity extraction + result = await self._process_document(text, entity_types) + except Exception as e: # pragma: no cover - defensive logging + logger.exception("error extracting graph") + self._on_error( + e, + traceback.format_exc(), + { + "source_id": source_id, + "text": text, + }, + ) + return _empty_entities_df(), _empty_relationships_df() + + return self._process_result( + result, + source_id, TUPLE_DELIMITER, RECORD_DELIMITER, ) - return GraphExtractionResult( - output=output, - source_docs=source_doc_map, - ) - async def _process_document(self, text: str, entity_types: list[str]) -> str: response = await self._model.achat( self._extraction_prompt.format(**{ @@ -133,125 +110,68 @@ async def _process_document(self, text: str, entity_types: list[str]) -> str: return results - async def _process_results( + def _process_result( self, - results: dict[int, str], + result: str, + source_id: str, tuple_delimiter: str, record_delimiter: str, - ) -> nx.Graph: - """Parse the result string to create an undirected unipartite graph. - - Args: - - results - dict of results from the extraction chain - - tuple_delimiter - delimiter between tuples in an output record, default is '<|>' - - record_delimiter - delimiter between records, default is '##' - Returns: - - output - unipartite graph in graphML format - """ - graph = nx.Graph() - for source_doc_id, extracted_data in results.items(): - records = [r.strip() for r in extracted_data.split(record_delimiter)] - - for record in records: - record = re.sub(r"^\(|\)$", "", record.strip()) - record_attributes = record.split(tuple_delimiter) - - if record_attributes[0] == '"entity"' and len(record_attributes) >= 4: - # add this record as a node in the G - entity_name = clean_str(record_attributes[1].upper()) - entity_type = clean_str(record_attributes[2].upper()) - entity_description = clean_str(record_attributes[3]) - - if entity_name in graph.nodes(): - node = graph.nodes[entity_name] - if self._join_descriptions: - node["description"] = "\n".join( - list({ - *_unpack_descriptions(node), - entity_description, - }) - ) - else: - if len(entity_description) > len(node["description"]): - node["description"] = entity_description - node["source_id"] = ", ".join( - list({ - *_unpack_source_ids(node), - str(source_doc_id), - }) - ) - node["type"] = ( - entity_type if entity_type != "" else node["type"] - ) - else: - graph.add_node( - entity_name, - type=entity_type, - description=entity_description, - source_id=str(source_doc_id), - ) - - if ( - record_attributes[0] == '"relationship"' - and len(record_attributes) >= 5 - ): - # add this record as edge - source = clean_str(record_attributes[1].upper()) - target = clean_str(record_attributes[2].upper()) - edge_description = clean_str(record_attributes[3]) - edge_source_id = clean_str(str(source_doc_id)) - try: - weight = float(record_attributes[-1]) - except ValueError: - weight = 1.0 - - if source not in graph.nodes(): - graph.add_node( - source, - type="", - description="", - source_id=edge_source_id, - ) - if target not in graph.nodes(): - graph.add_node( - target, - type="", - description="", - source_id=edge_source_id, - ) - if graph.has_edge(source, target): - edge_data = graph.get_edge_data(source, target) - if edge_data is not None: - weight += edge_data["weight"] - if self._join_descriptions: - edge_description = "\n".join( - list({ - *_unpack_descriptions(edge_data), - edge_description, - }) - ) - edge_source_id = ", ".join( - list({ - *_unpack_source_ids(edge_data), - str(source_doc_id), - }) - ) - graph.add_edge( - source, - target, - weight=weight, - description=edge_description, - source_id=edge_source_id, - ) + ) -> tuple[pd.DataFrame, pd.DataFrame]: + """Parse the result string into entity and relationship data frames.""" + entities: list[dict[str, Any]] = [] + relationships: list[dict[str, Any]] = [] + + records = [r.strip() for r in result.split(record_delimiter)] + + for raw_record in records: + record = re.sub(r"^\(|\)$", "", raw_record.strip()) + if not record or record == COMPLETION_DELIMITER: + continue + + record_attributes = record.split(tuple_delimiter) + record_type = record_attributes[0] + + if record_type == '"entity"' and len(record_attributes) >= 4: + entity_name = clean_str(record_attributes[1].upper()) + entity_type = clean_str(record_attributes[2].upper()) + entity_description = clean_str(record_attributes[3]) + entities.append({ + "title": entity_name, + "type": entity_type, + "description": entity_description, + "source_id": source_id, + }) + + if record_type == '"relationship"' and len(record_attributes) >= 5: + source = clean_str(record_attributes[1].upper()) + target = clean_str(record_attributes[2].upper()) + edge_description = clean_str(record_attributes[3]) + try: + weight = float(record_attributes[-1]) + except ValueError: + weight = 1.0 + + relationships.append({ + "source": source, + "target": target, + "description": edge_description, + "source_id": source_id, + "weight": weight, + }) + + entities_df = pd.DataFrame(entities) if entities else _empty_entities_df() + relationships_df = ( + pd.DataFrame(relationships) if relationships else _empty_relationships_df() + ) - return graph + return entities_df, relationships_df -def _unpack_descriptions(data: Mapping) -> list[str]: - value = data.get("description", None) - return [] if value is None else value.split("\n") +def _empty_entities_df() -> pd.DataFrame: + return pd.DataFrame(columns=["title", "type", "description", "source_id"]) -def _unpack_source_ids(data: Mapping) -> list[str]: - value = data.get("source_id", None) - return [] if value is None else value.split(", ") +def _empty_relationships_df() -> pd.DataFrame: + return pd.DataFrame( + columns=["source", "target", "weight", "description", "source_id"] + ) diff --git a/packages/graphrag/graphrag/index/operations/extract_graph/typing.py b/packages/graphrag/graphrag/index/operations/extract_graph/typing.py deleted file mode 100644 index d74eb9a476..0000000000 --- a/packages/graphrag/graphrag/index/operations/extract_graph/typing.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing 'Document' and 'EntityExtractionResult' models.""" - -from collections.abc import Awaitable, Callable -from dataclasses import dataclass -from typing import Any - -import networkx as nx - -from graphrag.cache.pipeline_cache import PipelineCache - -ExtractedEntity = dict[str, Any] -ExtractedRelationship = dict[str, Any] -StrategyConfig = dict[str, Any] -EntityTypes = list[str] - - -@dataclass -class Document: - """Document class definition.""" - - text: str - id: str - - -@dataclass -class EntityExtractionResult: - """Entity extraction result class definition.""" - - entities: list[ExtractedEntity] - relationships: list[ExtractedRelationship] - graph: nx.Graph | None - - -EntityExtractStrategy = Callable[ - [ - list[Document], - EntityTypes, - PipelineCache, - StrategyConfig, - ], - Awaitable[EntityExtractionResult], -] diff --git a/packages/graphrag/graphrag/index/operations/summarize_communities/typing.py b/packages/graphrag/graphrag/index/operations/summarize_communities/typing.py index 709c5ccc6a..e59c4f33b1 100644 --- a/packages/graphrag/graphrag/index/operations/summarize_communities/typing.py +++ b/packages/graphrag/graphrag/index/operations/summarize_communities/typing.py @@ -10,9 +10,7 @@ from graphrag.language_model.protocol.base import ChatModel -ExtractedEntity = dict[str, Any] RowContext = dict[str, Any] -EntityTypes = list[str] Claim = dict[str, Any] diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index a0b2011eab..a373e43b76 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -119,7 +119,7 @@ async def _run_pipeline( result = await workflow_function(config, context) context.callbacks.workflow_end(name, result) yield PipelineRunResult( - workflow=name, result=result.result, state=context.state, errors=None + workflow=name, result=result.result, state=context.state, error=None ) context.stats.workflows[name] = {"overall": time.time() - work_time} if result.stop: @@ -133,7 +133,7 @@ async def _run_pipeline( except Exception as e: logger.exception("error running workflow %s", last_workflow) yield PipelineRunResult( - workflow=last_workflow, result=None, state=context.state, errors=[e] + workflow=last_workflow, result=None, state=context.state, error=e ) diff --git a/packages/graphrag/graphrag/index/typing/pipeline_run_result.py b/packages/graphrag/graphrag/index/typing/pipeline_run_result.py index f6a68d82a0..b39e030268 100644 --- a/packages/graphrag/graphrag/index/typing/pipeline_run_result.py +++ b/packages/graphrag/graphrag/index/typing/pipeline_run_result.py @@ -19,4 +19,4 @@ class PipelineRunResult: """The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage.""" state: PipelineState """Ongoing pipeline context state object.""" - errors: list[BaseException] | None + error: BaseException | None diff --git a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py index c27d905c70..0751adb588 100644 --- a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py +++ b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py @@ -3,206 +3,76 @@ import unittest from graphrag.index.operations.extract_graph.extract_graph import run_extract_graph -from graphrag.index.operations.extract_graph.typing import ( - Document, -) from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT from tests.unit.indexing.verbs.helpers.mock_llm import create_mock_llm +SIMPLE_EXTRACTION_RESPONSE = """ +("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) +## +("entity"<|>TEST_ENTITY_2<|>COMPANY<|>TEST_ENTITY_2 owns TEST_ENTITY_1 and also shares an address with TEST_ENTITY_1) +## +("entity"<|>TEST_ENTITY_3<|>PERSON<|>TEST_ENTITY_3 is director of TEST_ENTITY_1) +## +("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_2<|>TEST_ENTITY_1 and TEST_ENTITY_2 are related because TEST_ENTITY_1 is 100% owned by TEST_ENTITY_2 and the two companies also share the same address)<|>2) +## +("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_3<|>TEST_ENTITY_1 and TEST_ENTITY_3 are related because TEST_ENTITY_3 is director of TEST_ENTITY_1<|>1)) +""".strip() + class TestRunChain(unittest.IsolatedAsyncioTestCase): async def test_run_extract_graph_single_document_correct_entities_returned(self): - results = await run_extract_graph( - docs=[Document("test_text", "1")], + entities_df, _ = await run_extract_graph( + text="test_text", + source_id="1", entity_types=["person"], max_gleanings=0, model=create_mock_llm( - responses=[ - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_2<|>COMPANY<|>TEST_ENTITY_2 owns TEST_ENTITY_1 and also shares an address with TEST_ENTITY_1) - ## - ("entity"<|>TEST_ENTITY_3<|>PERSON<|>TEST_ENTITY_3 is director of TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_2<|>TEST_ENTITY_1 and TEST_ENTITY_2 are related because TEST_ENTITY_1 is 100% owned by TEST_ENTITY_2 and the two companies also share the same address)<|>2) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_3<|>TEST_ENTITY_1 and TEST_ENTITY_3 are related because TEST_ENTITY_3 is director of TEST_ENTITY_1<|>1)) - """.strip() - ], + responses=[SIMPLE_EXTRACTION_RESPONSE], name="test_run_extract_graph_single_document_correct_entities_returned", ), prompt=GRAPH_EXTRACTION_PROMPT, ) - # self.assertItemsEqual isn't available yet, or I am just silly - # so we sort the lists and compare them - assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([ - entity["title"] for entity in results.entities - ]) - - async def test_run_extract_graph_multiple_documents_correct_entities_returned( - self, - ): - results = await run_extract_graph( - docs=[Document("text_1", "1"), Document("text_2", "2")], - entity_types=["person"], - max_gleanings=0, - model=create_mock_llm( - responses=[ - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_2<|>COMPANY<|>TEST_ENTITY_2 owns TEST_ENTITY_1 and also shares an address with TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_2<|>TEST_ENTITY_1 and TEST_ENTITY_2 are related because TEST_ENTITY_1 is 100% owned by TEST_ENTITY_2 and the two companies also share the same address)<|>2) - ## - """.strip(), - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_3<|>PERSON<|>TEST_ENTITY_3 is director of TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_3<|>TEST_ENTITY_1 and TEST_ENTITY_3 are related because TEST_ENTITY_3 is director of TEST_ENTITY_1<|>1)) - """.strip(), - ], - name="test_run_extract_graph_multiple_documents_correct_entities_returned", - ), - prompt=GRAPH_EXTRACTION_PROMPT, - ) - - # self.assertItemsEqual isn't available yet, or I am just silly - # so we sort the lists and compare them - assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([ - entity["title"] for entity in results.entities - ]) - - async def test_run_extract_graph_multiple_documents_correct_edges_returned(self): - results = await run_extract_graph( - docs=[Document("text_1", "1"), Document("text_2", "2")], - entity_types=["person"], - max_gleanings=0, - model=create_mock_llm( - responses=[ - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_2<|>COMPANY<|>TEST_ENTITY_2 owns TEST_ENTITY_1 and also shares an address with TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_2<|>TEST_ENTITY_1 and TEST_ENTITY_2 are related because TEST_ENTITY_1 is 100% owned by TEST_ENTITY_2 and the two companies also share the same address)<|>2) - ## - """.strip(), - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_3<|>PERSON<|>TEST_ENTITY_3 is director of TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_3<|>TEST_ENTITY_1 and TEST_ENTITY_3 are related because TEST_ENTITY_3 is director of TEST_ENTITY_1<|>1)) - """.strip(), - ], - name="test_run_extract_graph_multiple_documents_correct_edges_returned", - ), - prompt=GRAPH_EXTRACTION_PROMPT, + assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted( + entities_df["title"].tolist() ) - # self.assertItemsEqual isn't available yet, or I am just silly - # so we sort the lists and compare them - graph = results.graph - assert graph is not None, "No graph returned!" - - # convert to strings for more visual comparison - edges_str = sorted([f"{edge[0]} -> {edge[1]}" for edge in graph.edges]) - assert edges_str == sorted([ - "TEST_ENTITY_1 -> TEST_ENTITY_2", - "TEST_ENTITY_1 -> TEST_ENTITY_3", - ]) - - async def test_run_extract_graph_multiple_documents_correct_entity_source_ids_mapped( - self, - ): - results = await run_extract_graph( - docs=[Document("text_1", "1"), Document("text_2", "2")], + async def test_run_extract_graph_single_document_correct_edges_returned(self): + _, relationships_df = await run_extract_graph( + text="test_text", + source_id="1", entity_types=["person"], max_gleanings=0, model=create_mock_llm( - responses=[ - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_2<|>COMPANY<|>TEST_ENTITY_2 owns TEST_ENTITY_1 and also shares an address with TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_2<|>TEST_ENTITY_1 and TEST_ENTITY_2 are related because TEST_ENTITY_1 is 100% owned by TEST_ENTITY_2 and the two companies also share the same address)<|>2) - ## - """.strip(), - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_3<|>PERSON<|>TEST_ENTITY_3 is director of TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_3<|>TEST_ENTITY_1 and TEST_ENTITY_3 are related because TEST_ENTITY_3 is director of TEST_ENTITY_1<|>1)) - """.strip(), - ], - name="test_run_extract_graph_multiple_documents_correct_entity_source_ids_mapped", + responses=[SIMPLE_EXTRACTION_RESPONSE], + name="test_run_extract_graph_single_document_correct_edges_returned", ), prompt=GRAPH_EXTRACTION_PROMPT, ) - graph = results.graph - assert graph is not None, "No graph returned!" + edges = relationships_df.to_dict("records") + assert len(edges) == 2 - # TODO: The edges might come back in any order, but we're assuming they're coming - # back in the order that we passed in the docs, that might not be true - assert ( - graph.nodes["TEST_ENTITY_3"].get("source_id") == "2" - ) # TEST_ENTITY_3 should be in just 2 - assert ( - graph.nodes["TEST_ENTITY_2"].get("source_id") == "1" - ) # TEST_ENTITY_2 should be in just 1 - ids_str = graph.nodes["TEST_ENTITY_1"].get("source_id") or "" - assert sorted(ids_str.split(",")) == sorted([ - "1", - "2", - ]) # TEST_ENTITY_1 should be 1 and 2 + relationship_pairs = {(edge["source"], edge["target"]) for edge in edges} + assert relationship_pairs == { + ("TEST_ENTITY_1", "TEST_ENTITY_2"), + ("TEST_ENTITY_1", "TEST_ENTITY_3"), + } - async def test_run_extract_graph_multiple_documents_correct_edge_source_ids_mapped( - self, - ): - results = await run_extract_graph( - docs=[Document("text_1", "1"), Document("text_2", "2")], + async def test_run_extract_graph_single_document_source_ids_mapped(self): + entities_df, relationships_df = await run_extract_graph( + text="test_text", + source_id="1", entity_types=["person"], max_gleanings=0, model=create_mock_llm( - responses=[ - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_2<|>COMPANY<|>TEST_ENTITY_2 owns TEST_ENTITY_1 and also shares an address with TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_2<|>TEST_ENTITY_1 and TEST_ENTITY_2 are related because TEST_ENTITY_1 is 100% owned by TEST_ENTITY_2 and the two companies also share the same address)<|>2) - ## - """.strip(), - """ - ("entity"<|>TEST_ENTITY_1<|>COMPANY<|>TEST_ENTITY_1 is a test company) - ## - ("entity"<|>TEST_ENTITY_3<|>PERSON<|>TEST_ENTITY_3 is director of TEST_ENTITY_1) - ## - ("relationship"<|>TEST_ENTITY_1<|>TEST_ENTITY_3<|>TEST_ENTITY_1 and TEST_ENTITY_3 are related because TEST_ENTITY_3 is director of TEST_ENTITY_1<|>1)) - """.strip(), - ], - name="test_run_extract_graph_multiple_documents_correct_edge_source_ids_mapped", + responses=[SIMPLE_EXTRACTION_RESPONSE], + name="test_run_extract_graph_single_document_source_ids_mapped", ), prompt=GRAPH_EXTRACTION_PROMPT, ) - graph = results.graph - assert graph is not None, "No graph returned!" - edges = list(graph.edges(data=True)) - - # should only have 2 edges - assert len(edges) == 2 + assert all(source_id == "1" for source_id in entities_df["source_id"]) - # Sort by source_id for consistent ordering - edge_source_ids = sorted([edge[2].get("source_id", "") for edge in edges]) - assert edge_source_ids[0].split(",") == ["1"] - assert edge_source_ids[1].split(",") == ["2"] + assert all(source_id == "1" for source_id in relationships_df["source_id"])