diff --git a/packages/graphrag/graphrag/index/operations/build_noun_graph/build_noun_graph.py b/packages/graphrag/graphrag/index/operations/build_noun_graph/build_noun_graph.py index 8d3310e766..5e34aa654c 100644 --- a/packages/graphrag/graphrag/index/operations/build_noun_graph/build_noun_graph.py +++ b/packages/graphrag/graphrag/index/operations/build_noun_graph/build_noun_graph.py @@ -98,9 +98,11 @@ def _extract_edges( Input: nodes_df with schema [id, title, frequency, text_unit_ids] Returns: edges_df with schema [source, target, weight, text_unit_ids] """ + if nodes_df.empty: + return pd.DataFrame(columns=["source", "target", "weight", "text_unit_ids"]) + text_units_df = nodes_df.explode("text_unit_ids") text_units_df = text_units_df.rename(columns={"text_unit_ids": "text_unit_id"}) - text_units_df = ( text_units_df.groupby("text_unit_id") .agg({"title": lambda x: list(x) if len(x) > 1 else np.nan}) diff --git a/packages/graphrag/graphrag/index/operations/prune_graph.py b/packages/graphrag/graphrag/index/operations/prune_graph.py index b03dfb7b17..6eb6b60065 100644 --- a/packages/graphrag/graphrag/index/operations/prune_graph.py +++ b/packages/graphrag/graphrag/index/operations/prune_graph.py @@ -64,6 +64,9 @@ def prune_graph( ]) # remove edges by min weight + if len(graph.edges) == 0: + return graph + if min_edge_weight_pct > 0: min_edge_weight = np.percentile( [data[schemas.EDGE_WEIGHT] for _, _, data in graph.edges(data=True)], diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph.py b/packages/graphrag/graphrag/index/workflows/extract_graph.py index de9c454261..8ce8669109 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph.py @@ -121,14 +121,14 @@ async def extract_graph( async_type=extraction_async_type, ) - if not _validate_data(extracted_entities): - error_msg = "Entity Extraction failed. No entities detected during extraction." + if len(extracted_entities) == 0: + error_msg = "Graph Extraction failed. No entities detected during extraction." logger.error(error_msg) raise ValueError(error_msg) - if not _validate_data(extracted_relationships): + if len(extracted_relationships) == 0: error_msg = ( - "Entity Extraction failed. No relationships detected during extraction." + "Graph Extraction failed. No relationships detected during extraction." ) logger.error(error_msg) raise ValueError(error_msg) @@ -180,8 +180,3 @@ async def get_summarized_entities_relationships( extracted_entities.drop(columns=["description"], inplace=True) entities = extracted_entities.merge(entity_summaries, on="title", how="left") return entities, relationships - - -def _validate_data(df: pd.DataFrame) -> bool: - """Validate that the dataframe has data.""" - return len(df) > 0 diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py b/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py index b49f3e7a1f..fd00cbfa79 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py @@ -75,6 +75,19 @@ async def extract_graph_nlp( cache=cache, ) + if len(extracted_nodes) == 0: + error_msg = ( + "NLP Graph Extraction failed. No entities detected during extraction." + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if len(extracted_edges) == 0: + error_msg = ( + "NLP Graph Extraction failed. No relationships detected during extraction." + ) + logger.error(error_msg) + # add in any other columns required by downstream workflows extracted_nodes["type"] = "NOUN PHRASE" extracted_nodes["description"] = "" diff --git a/packages/graphrag/graphrag/index/workflows/prune_graph.py b/packages/graphrag/graphrag/index/workflows/prune_graph.py index 8bb48df7ee..da4e2bc835 100644 --- a/packages/graphrag/graphrag/index/workflows/prune_graph.py +++ b/packages/graphrag/graphrag/index/workflows/prune_graph.py @@ -69,6 +69,16 @@ def prune_graph( lcc_only=pruning_config.lcc_only, ) + if len(pruned.nodes) == 0: + error_msg = "Graph Pruning failed. No entities remain." + logger.error(error_msg) + raise ValueError(error_msg) + + if len(pruned.edges) == 0: + error_msg = "Graph Pruning failed. No relationships remain." + logger.error(error_msg) + raise ValueError(error_msg) + pruned_nodes, pruned_edges = graph_to_dataframes( pruned, node_columns=["title"], edge_columns=["source", "target"] )