Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250422210800599071.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Brings parity with our latest NLP extraction approaches."
}
2 changes: 1 addition & 1 deletion docs/config/yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ Parameters for manual graph pruning. This can be used to optimize the modularity
- max_node_freq_std **float | None** - The maximum standard deviation of node frequency to allow.
- min_node_degree **int** - The minimum node degree to allow.
- max_node_degree_std **float | None** - The maximum standard deviation of node degree to allow.
- min_edge_weight_pct **int** - The minimum edge weight percentile to allow.
- min_edge_weight_pct **float** - The minimum edge weight percentile to allow.
- remove_ego_nodes **bool** - Remove ego nodes.
- lcc_only **bool** - Only use largest connected component.

Expand Down
9 changes: 6 additions & 3 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
ReportingType,
TextEmbeddingTarget,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
)
from graphrag.vector_stores.factory import VectorStoreType

DEFAULT_OUTPUT_BASE_DIR = "output"
Expand Down Expand Up @@ -186,7 +189,7 @@ class TextAnalyzerDefaults:
max_word_length: int = 15
word_delimiter: str = " "
include_named_entities: bool = True
exclude_nouns: None = None
exclude_nouns: list[str] = field(default_factory=lambda: EN_STOP_WORDS)
exclude_entity_tags: list[str] = field(default_factory=lambda: ["DATE"])
exclude_pos_tags: list[str] = field(
default_factory=lambda: ["DET", "PRON", "INTJ", "X"]
Expand Down Expand Up @@ -317,8 +320,8 @@ class PruneGraphDefaults:
max_node_freq_std: None = None
min_node_degree: int = 1
max_node_degree_std: None = None
min_edge_weight_pct: int = 40
remove_ego_nodes: bool = False
min_edge_weight_pct: float = 40.0
remove_ego_nodes: bool = True
lcc_only: bool = False


Expand Down
65 changes: 22 additions & 43 deletions graphrag/index/operations/build_noun_graph/build_noun_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

"""Graph extraction using NLP."""

import math
from itertools import combinations

import numpy as np
import pandas as pd

from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
Expand All @@ -30,7 +31,6 @@ async def build_noun_graph(
text_units, text_analyzer, num_threads=num_threads, cache=cache
)
edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)

return (nodes_df, edges_df)


Expand Down Expand Up @@ -69,7 +69,7 @@ async def extract(row):
noun_node_df = text_unit_df.explode("noun_phrases")
noun_node_df = noun_node_df.rename(
columns={"noun_phrases": "title", "id": "text_unit_id"}
).drop_duplicates()
)

# group by title and count the number of text units
grouped_node_df = (
Expand All @@ -94,64 +94,44 @@ def _extract_edges(
"""
text_units_df = nodes_df.explode("text_unit_ids")
text_units_df = text_units_df.rename(columns={"text_unit_ids": "text_unit_id"})

text_units_df = (
text_units_df.groupby("text_unit_id").agg({"title": list}).reset_index()
text_units_df.groupby("text_unit_id")
.agg({"title": lambda x: list(x) if len(x) > 1 else np.nan})
.reset_index()
)
text_units_df["edges"] = text_units_df["title"].apply(
lambda x: _create_relationships(x)
text_units_df = text_units_df.dropna()
titles = text_units_df["title"].tolist()
all_edges: list[list[tuple[str, str]]] = [list(combinations(t, 2)) for t in titles]

text_units_df = text_units_df.assign(edges=all_edges) # type: ignore
edge_df = text_units_df.explode("edges")[["edges", "text_unit_id"]]

edge_df[["source", "target"]] = edge_df.loc[:, "edges"].to_list()
edge_df["min_source"] = edge_df[["source", "target"]].min(axis=1)
edge_df["max_target"] = edge_df[["source", "target"]].max(axis=1)
edge_df = edge_df.drop(columns=["source", "target"]).rename(
columns={"min_source": "source", "max_target": "target"} # type: ignore
)
edge_df = text_units_df.explode("edges").loc[:, ["edges", "text_unit_id"]]

edge_df["source"] = edge_df["edges"].apply(
lambda x: x[0] if isinstance(x, tuple) else None
)
edge_df["target"] = edge_df["edges"].apply(
lambda x: x[1] if isinstance(x, tuple) else None
)
edge_df = edge_df[(edge_df.source.notna()) & (edge_df.target.notna())]
edge_df = edge_df.drop(columns=["edges"])

# make sure source is always smaller than target
edge_df["source"], edge_df["target"] = zip(
*edge_df.apply(
lambda x: (x["source"], x["target"])
if x["source"] < x["target"]
else (x["target"], x["source"]),
axis=1,
),
strict=False,
)

# group by source and target, count the number of text units and collect their ids
# group by source and target, count the number of text units
grouped_edge_df = (
edge_df.groupby(["source", "target"]).agg({"text_unit_id": list}).reset_index()
)
grouped_edge_df = grouped_edge_df.rename(columns={"text_unit_id": "text_unit_ids"})
grouped_edge_df["weight"] = grouped_edge_df["text_unit_ids"].apply(len)

grouped_edge_df = grouped_edge_df.loc[
:, ["source", "target", "weight", "text_unit_ids"]
]

if normalize_edge_weights:
# use PMI weight instead of raw weight
grouped_edge_df = _calculate_pmi_edge_weights(nodes_df, grouped_edge_df)

return grouped_edge_df


def _create_relationships(
noun_phrases: list[str],
) -> list[tuple[str, str]]:
"""Create a (source, target) tuple pairwise for all noun phrases in a list."""
relationships = []
if len(noun_phrases) >= 2:
for i in range(len(noun_phrases) - 1):
for j in range(i + 1, len(noun_phrases)):
relationships.extend([(noun_phrases[i], noun_phrases[j])])
return relationships


def _calculate_pmi_edge_weights(
nodes_df: pd.DataFrame,
edges_df: pd.DataFrame,
Expand Down Expand Up @@ -192,8 +172,7 @@ def _calculate_pmi_edge_weights(
.drop(columns=[node_name_col])
.rename(columns={"prop_occurrence": "target_prop"})
)
edges_df[edge_weight_col] = edges_df.apply(
lambda x: math.log2(x["prop_weight"] / (x["source_prop"] * x["target_prop"])),
axis=1,
edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
)
return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])
8 changes: 8 additions & 0 deletions graphrag/index/operations/graph_to_dataframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ def graph_to_dataframes(

edges = nx.to_pandas_edgelist(graph)

# The graph is undirected, but DataFrame joins need a consistent (source, target) ordering.
# networkx does not preserve the original endpoint order, so normalize each edge to source <= target.
edges["min_source"] = edges[["source", "target"]].min(axis=1)
edges["max_target"] = edges[["source", "target"]].max(axis=1)
edges = edges.drop(columns=["source", "target"]).rename(
columns={"min_source": "source", "max_target": "target"} # type: ignore
)

if node_columns:
nodes = nodes.loc[:, node_columns]

Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/operations/prune_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def prune_graph(
max_node_freq_std: float | None = None,
min_node_degree: int = 1,
max_node_degree_std: float | None = None,
min_edge_weight_pct: float = 0,
min_edge_weight_pct: float = 40,
remove_ego_nodes: bool = False,
lcc_only: bool = False,
) -> nx.Graph:
Expand Down
3 changes: 2 additions & 1 deletion graphrag/index/workflows/finalize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ async def run_workflow(

if config.snapshots.graphml:
# todo: extract graphs at each level, and add in meta like descriptions
graph = create_graph(relationships)
graph = create_graph(final_relationships, edge_attr=["weight"])

await snapshot_graphml(
graph,
name="graph",
Expand Down
2 changes: 1 addition & 1 deletion tests/verbs/test_prune_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ async def test_prune_graph():

nodes_actual = await load_table_from_storage("entities", context.storage)

assert len(nodes_actual) == 21
assert len(nodes_actual) == 20
Loading