Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250515212234042330.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "A few fixes and enhancements for better reuse and flow."
}
3 changes: 1 addition & 2 deletions docs/config/yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ Supported embeddings names are:
- `vector_store_id` **str** - Name of vector store definition to write to.
- `batch_size` **int** - The maximum batch size to use.
- `batch_max_tokens` **int** - The maximum batch # of tokens.
- `target` **required|all|selected|none** - Determines which set of embeddings to export.
- `names` **list[str]** - If target=selected, this should be an explicit list of the embeddings names we support.
- `names` **list[str]** - List of the embeddings names to run (must be in supported list).

### extract_graph

Expand Down
5 changes: 2 additions & 3 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Literal

from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import (
AsyncType,
AuthType,
Expand All @@ -18,7 +19,6 @@
NounPhraseExtractorType,
OutputType,
ReportingType,
TextEmbeddingTarget,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
Expand Down Expand Up @@ -147,9 +147,8 @@ class EmbedTextDefaults:
model: str = "text-embedding-3-small"
batch_size: int = 16
batch_max_tokens: int = 8191
target = TextEmbeddingTarget.required
model_id: str = DEFAULT_EMBEDDING_MODEL_ID
names: list[str] = field(default_factory=list)
names: list[str] = field(default_factory=lambda: default_embeddings)
strategy: None = None
vector_store_id: str = DEFAULT_VECTOR_STORE_ID

Expand Down
56 changes: 2 additions & 54 deletions graphrag/config/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

"""A module containing embeddings values."""

from graphrag.config.enums import TextEmbeddingTarget
from graphrag.config.models.graph_rag_config import GraphRagConfig

entity_title_embedding = "entity.title"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
Expand All @@ -25,60 +22,11 @@
community_full_content_embedding,
text_unit_text_embedding,
}
required_embeddings: set[str] = {
default_embeddings: list[str] = [
entity_description_embedding,
community_full_content_embedding,
text_unit_text_embedding,
}


def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
"""Get the fields to embed based on the enum or specifically selected embeddings."""
match settings.embed_text.target:
case TextEmbeddingTarget.all:
return all_embeddings
case TextEmbeddingTarget.required:
return required_embeddings
case TextEmbeddingTarget.selected:
return set(settings.embed_text.names)
case TextEmbeddingTarget.none:
return set()
case _:
msg = f"Unknown embeddings target: {settings.embed_text.target}"
raise ValueError(msg)


def get_embedding_settings(
settings: GraphRagConfig,
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
# TEMP
embeddings_llm_settings = settings.get_language_model_config(
settings.embed_text.model_id
)
vector_store_settings = settings.get_vector_store_config(
settings.embed_text.vector_store_id
).model_dump()

#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.embed_text.resolved_strategy(
embeddings_llm_settings
) # get the default strategy
strategy.update({
"vector_store": {
**(vector_store_params or {}),
**(vector_store_settings),
}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
}
]


def create_collection_name(
Expand Down
25 changes: 12 additions & 13 deletions graphrag/config/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,6 @@ def __repr__(self):
return f'"{self.value}"'


class TextEmbeddingTarget(str, Enum):
    """Enumeration of which set of text embeddings to generate."""

    all = "all"
    required = "required"
    selected = "selected"
    none = "none"

    def __repr__(self):
        """Render the member as its quoted string value."""
        return '"{}"'.format(self.value)


class ModelType(str, Enum):
"""LLMType enum class definition."""

Expand Down Expand Up @@ -176,3 +163,15 @@ class NounPhraseExtractorType(str, Enum):
"""Noun phrase extractor based on dependency parsing and NER using SpaCy."""
CFG = "cfg"
"""Noun phrase extractor combining CFG-based noun-chunk extraction and NER."""


class ModularityMetric(str, Enum):
    """Enum for the modularity metric to use."""

    Graph = "graph"
    """Graph modularity metric."""

    LCC = "lcc"
    """Largest-connected-component modularity metric."""

    WeightedComponents = "weighted_components"
    """Weighted components modularity metric."""
39 changes: 39 additions & 0 deletions graphrag/config/get_embedding_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing get_embedding_settings."""

from graphrag.config.models.graph_rag_config import GraphRagConfig


def get_embedding_settings(
    settings: GraphRagConfig,
    vector_store_params: dict | None = None,
) -> dict:
    """Transform GraphRAG config into settings for workflows."""
    # TEMP
    model_config = settings.get_language_model_config(settings.embed_text.model_id)
    store_config = settings.get_vector_store_config(
        settings.embed_text.vector_store_id
    ).model_dump()

    # At this point settings.vector_store is defined and there is a specific
    # setting for this embedding:
    #   settings.vector_store.base          -> connection info (may be undefined)
    #   settings.vector_store.<vector_name> -> settings for this embedding
    strategy = settings.embed_text.resolved_strategy(model_config)

    # Fold the vector-store configuration into the strategy so it travels with
    # the workflow settings instead of the global config. Caller-supplied
    # params are merged first, so the resolved store config wins on conflicts.
    merged_store = dict(vector_store_params or {})
    merged_store.update(store_config)
    strategy["vector_store"] = merged_store

    return {"strategy": strategy}
5 changes: 0 additions & 5 deletions graphrag/config/models/text_embedding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pydantic import BaseModel, Field

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import TextEmbeddingTarget
from graphrag.config.models.language_model_config import LanguageModelConfig


Expand All @@ -29,10 +28,6 @@ class TextEmbeddingConfig(BaseModel):
description="The batch max tokens to use.",
default=graphrag_config_defaults.embed_text.batch_max_tokens,
)
target: TextEmbeddingTarget = Field(
description="The target to use. 'all', 'required', 'selected', or 'none'.",
default=graphrag_config_defaults.embed_text.target,
)
names: list[str] = Field(
description="The specific embeddings to perform.",
default=graphrag_config_defaults.embed_text.names,
Expand Down
5 changes: 5 additions & 0 deletions graphrag/data_model/community.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class Community(Named):
relationship_ids: list[str] | None = None
"""List of relationship IDs related to the community (optional)."""

text_unit_ids: list[str] | None = None
"""List of text unit IDs related to the community (optional)."""

covariate_ids: dict[str, list[str]] | None = None
"""Dictionary of different types of covariates related to the community (optional), e.g. claims"""

Expand All @@ -50,6 +53,7 @@ def from_dict(
level_key: str = "level",
entities_key: str = "entity_ids",
relationships_key: str = "relationship_ids",
text_units_key: str = "text_unit_ids",
covariates_key: str = "covariate_ids",
parent_key: str = "parent",
children_key: str = "children",
Expand All @@ -67,6 +71,7 @@ def from_dict(
short_id=d.get(short_id_key),
entity_ids=d.get(entities_key),
relationship_ids=d.get(relationships_key),
text_unit_ids=d.get(text_units_key),
covariate_ids=d.get(covariates_key),
attributes=d.get(attributes_key),
size=d.get(size_key),
Expand Down
49 changes: 2 additions & 47 deletions graphrag/index/operations/build_noun_graph/build_noun_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
BaseNounPhraseExtractor,
)
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.index.utils.graphs import calculate_pmi_edge_weights
from graphrag.index.utils.hashing import gen_sha512_hash


Expand Down Expand Up @@ -127,52 +128,6 @@ def _extract_edges(
]
if normalize_edge_weights:
# use PMI weight instead of raw weight
grouped_edge_df = _calculate_pmi_edge_weights(nodes_df, grouped_edge_df)
grouped_edge_df = calculate_pmi_edge_weights(nodes_df, grouped_edge_df)

return grouped_edge_df


def _calculate_pmi_edge_weights(
nodes_df: pd.DataFrame,
edges_df: pd.DataFrame,
node_name_col="title",
node_freq_col="frequency",
edge_weight_col="weight",
edge_source_col="source",
edge_target_col="target",
) -> pd.DataFrame:
"""
Calculate pointwise mutual information (PMI) edge weights.

pmi(x,y) = log2(p(x,y) / (p(x)p(y)))
p(x,y) = edge_weight(x,y) / total_edge_weights
p(x) = freq_occurrence(x) / total_freq_occurrences
"""
copied_nodes_df = nodes_df[[node_name_col, node_freq_col]]

total_edge_weights = edges_df[edge_weight_col].sum()
total_freq_occurrences = nodes_df[node_freq_col].sum()
copied_nodes_df["prop_occurrence"] = (
copied_nodes_df[node_freq_col] / total_freq_occurrences
)
copied_nodes_df = copied_nodes_df.loc[:, [node_name_col, "prop_occurrence"]]

edges_df["prop_weight"] = edges_df[edge_weight_col] / total_edge_weights
edges_df = (
edges_df.merge(
copied_nodes_df, left_on=edge_source_col, right_on=node_name_col, how="left"
)
.drop(columns=[node_name_col])
.rename(columns={"prop_occurrence": "source_prop"})
)
edges_df = (
edges_df.merge(
copied_nodes_df, left_on=edge_target_col, right_on=node_name_col, how="left"
)
.drop(columns=[node_name_col])
.rename(columns={"prop_occurrence": "target_prop"})
)
edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
)
return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])
4 changes: 1 addition & 3 deletions graphrag/index/operations/cluster_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging

import networkx as nx
from graspologic.partition import hierarchical_leiden
Comment thread
natoverse marked this conversation as resolved.

from graphrag.index.utils.stable_lcc import stable_largest_connected_component

Expand Down Expand Up @@ -60,9 +61,6 @@ def _compute_leiden_communities(
seed: int | None = None,
) -> tuple[dict[int, dict[str, int]], dict[int, int]]:
"""Return Leiden root communities and their hierarchy mapping."""
# NOTE: This import is done here to reduce the initial import time of the graphrag package
from graspologic.partition import hierarchical_leiden

if use_lcc:
graph = stable_largest_connected_component(graph)

Expand Down
Loading
Loading