Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20250519234123676262.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Allow injection of custom pipelines."
}
11 changes: 9 additions & 2 deletions graphrag/api/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

async def build_index(
config: GraphRagConfig,
method: IndexingMethod = IndexingMethod.Standard,
method: IndexingMethod | str = IndexingMethod.Standard,
is_update_run: bool = False,
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
Expand Down Expand Up @@ -65,7 +65,9 @@ async def build_index(
if memory_profile:
log.warning("New pipeline does not yet support memory profiling.")

pipeline = PipelineFactory.create_pipeline(config, method, is_update_run)
# todo: this could propagate out to the cli for better clarity, but will be a breaking api change
method = _get_method(method, is_update_run)
pipeline = PipelineFactory.create_pipeline(config, method)

workflow_callbacks.pipeline_start(pipeline.names())

Expand All @@ -90,3 +92,8 @@ async def build_index(
def register_workflow_function(name: str, workflow: WorkflowFunction):
    """Register a custom workflow function. You can then include the name in the settings.yaml workflows list.

    Args
    ----
    name: str
        Identifier the pipeline uses to look up this workflow (as listed in settings.yaml).
    workflow: WorkflowFunction
        The callable to execute when the named workflow runs.
    """
    PipelineFactory.register(name, workflow)


def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str:
    """Resolve the pipeline method name to register with the PipelineFactory.

    Args
    ----
    method: IndexingMethod | str
        The indexing method, as an enum member or its string value.
    is_update_run: bool
        Whether this is an incremental-update run; if so, the "-update"
        variant of the method is selected.

    Returns
    -------
    str
        The pipeline name to pass to PipelineFactory.create_pipeline.
    """
    m = method.value if isinstance(method, IndexingMethod) else method
    # Guard against doubling the suffix when a caller already passes an
    # update method (e.g. IndexingMethod.StandardUpdate or "standard-update")
    # together with is_update_run=True.
    if is_update_run and not m.endswith("-update"):
        return f"{m}-update"
    return m
2 changes: 0 additions & 2 deletions graphrag/api/prompt_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
async def generate_indexing_prompts(
config: GraphRagConfig,
logger: ProgressLogger,
root: str,
chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
overlap: Annotated[
int, annotated_types.Gt(-1)
Expand Down Expand Up @@ -93,7 +92,6 @@ async def generate_indexing_prompts(
# Retrieve documents
logger.info("Chunking documents...")
doc_list = await load_docs_in_chunks(
root=root,
config=config,
limit=limit,
select_method=selection_method,
Expand Down
1 change: 0 additions & 1 deletion graphrag/cli/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def index_cli(
cli_overrides["reporting.base_dir"] = str(output_dir)
cli_overrides["update_index_output.base_dir"] = str(output_dir)
config = load_config(root_dir, config_filepath, cli_overrides)

_run_index(
config=config,
method=method,
Expand Down
1 change: 0 additions & 1 deletion graphrag/cli/prompt_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ async def prompt_tune(

prompts = await api.generate_indexing_prompts(
config=graph_config,
root=str(root_path),
logger=progress_logger,
chunk_size=chunk_size,
overlap=overlap,
Expand Down
42 changes: 24 additions & 18 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
CacheType,
ChunkStrategyType,
InputFileType,
InputType,
ModelType,
NounPhraseExtractorType,
OutputType,
ReportingType,
StorageType,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
Expand Down Expand Up @@ -234,16 +233,31 @@ class GlobalSearchDefaults:
chat_model_id: str = DEFAULT_CHAT_MODEL_ID


@dataclass
class StorageDefaults:
    """Default values for storage."""

    # NOTE(review): `type` is intentionally left unannotated, so the dataclass
    # treats it as a plain class attribute rather than an init field — this
    # matches the pattern used by the other defaults classes in this file;
    # confirm before adding an annotation, which would change __init__.
    type = StorageType.file
    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
    # Azure-specific connection settings; all unused for file storage.
    connection_string: None = None
    container_name: None = None
    storage_account_blob_url: None = None
    cosmosdb_account_url: None = None


@dataclass
class InputStorageDefaults(StorageDefaults):
    """Default values for input storage."""

    # Input documents live under "input", unlike the shared output base dir
    # inherited from StorageDefaults.
    base_dir: str = "input"


@dataclass
class InputDefaults:
"""Default values for input."""

type = InputType.file
storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
file_type = InputFileType.text
base_dir: str = "input"
connection_string: None = None
storage_account_blob_url: None = None
container_name: None = None
encoding: str = "utf-8"
file_pattern: str = ""
file_filter: None = None
Expand Down Expand Up @@ -301,15 +315,10 @@ class LocalSearchDefaults:


@dataclass
class OutputDefaults:
class OutputDefaults(StorageDefaults):
"""Default values for output."""

type = OutputType.file
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
connection_string: None = None
container_name: None = None
storage_account_blob_url: None = None
cosmosdb_account_url: None = None


@dataclass
Expand Down Expand Up @@ -364,14 +373,10 @@ class UmapDefaults:


@dataclass
class UpdateIndexOutputDefaults:
class UpdateIndexOutputDefaults(StorageDefaults):
"""Default values for update index output."""

type = OutputType.file
base_dir: str = "update_output"
connection_string: None = None
container_name: None = None
storage_account_blob_url: None = None


@dataclass
Expand All @@ -395,6 +400,7 @@ class GraphRagConfigDefaults:
root_dir: str = ""
models: dict = field(default_factory=dict)
reporting: ReportingDefaults = field(default_factory=ReportingDefaults)
storage: StorageDefaults = field(default_factory=StorageDefaults)
output: OutputDefaults = field(default_factory=OutputDefaults)
outputs: None = None
update_index_output: UpdateIndexOutputDefaults = field(
Expand Down
19 changes: 5 additions & 14 deletions graphrag/config/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,7 @@ def __repr__(self):
return f'"{self.value}"'


class InputType(str, Enum):
"""The input type for the pipeline."""

file = "file"
"""The file storage type."""
blob = "blob"
"""The blob storage type."""

def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'


class OutputType(str, Enum):
class StorageType(str, Enum):
"""The output type for the pipeline."""

file = "file"
Expand Down Expand Up @@ -152,6 +139,10 @@ class IndexingMethod(str, Enum):
"""Traditional GraphRAG indexing, with all graph construction and summarization performed by a language model."""
Fast = "fast"
"""Fast indexing, using NLP for graph construction and language model for summarization."""
StandardUpdate = "standard-update"
"""Incremental update with standard indexing."""
FastUpdate = "fast-update"
"""Incremental update with fast indexing."""


class NounPhraseExtractorType(str, Enum):
Expand Down
6 changes: 4 additions & 2 deletions graphrag/config/init_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@
### Input settings ###

input:
type: {graphrag_config_defaults.input.type.value} # or blob
storage:
type: {graphrag_config_defaults.input.storage.type.value} # or blob
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
base_dir: "{graphrag_config_defaults.input.base_dir}"


chunks:
size: {graphrag_config_defaults.chunks.size}
Expand Down
30 changes: 20 additions & 10 deletions graphrag/config/models/graph_rag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.local_search_config import LocalSearchConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.prune_graph_config import PruneGraphConfig
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.summarize_descriptions_config import (
SummarizeDescriptionsConfig,
)
Expand Down Expand Up @@ -102,29 +102,39 @@ def _validate_input_pattern(self) -> None:
else:
self.input.file_pattern = f".*\\.{self.input.file_type.value}$"

def _validate_input_base_dir(self) -> None:
    """Validate and resolve the input storage base directory.

    For file-based input storage, a non-blank base_dir is required; it is
    then resolved to an absolute path under root_dir.
    """
    storage = self.input.storage
    if storage.type != defs.StorageType.file:
        return
    if not storage.base_dir.strip():
        msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
        raise ValueError(msg)
    # Anchor the relative base_dir at the project root.
    storage.base_dir = str((Path(self.root_dir) / storage.base_dir).resolve())

chunks: ChunkingConfig = Field(
description="The chunking configuration to use.",
default=ChunkingConfig(),
)
"""The chunking configuration to use."""

output: OutputConfig = Field(
output: StorageConfig = Field(
description="The output configuration.",
default=OutputConfig(),
default=StorageConfig(),
)
"""The output configuration."""

def _validate_output_base_dir(self) -> None:
    """Validate and resolve the output base directory.

    For file-based output, a non-blank base_dir is required; it is then
    resolved to an absolute path under root_dir.
    """
    if self.output.type != defs.StorageType.file:
        return
    if not self.output.base_dir.strip():
        msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
        raise ValueError(msg)
    # Anchor the relative base_dir at the project root.
    resolved = (Path(self.root_dir) / self.output.base_dir).resolve()
    self.output.base_dir = str(resolved)

outputs: dict[str, OutputConfig] | None = Field(
outputs: dict[str, StorageConfig] | None = Field(
description="A list of output configurations used for multi-index query.",
default=graphrag_config_defaults.outputs,
)
Expand All @@ -133,26 +143,25 @@ def _validate_multi_output_base_dirs(self) -> None:
"""Validate the outputs dict base directories."""
if self.outputs:
for output in self.outputs.values():
if output.type == defs.OutputType.file:
if output.type == defs.StorageType.file:
if output.base_dir.strip() == "":
msg = "Output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
raise ValueError(msg)
output.base_dir = str(
(Path(self.root_dir) / output.base_dir).resolve()
)

update_index_output: OutputConfig = Field(
update_index_output: StorageConfig = Field(
description="The output configuration for the updated index.",
default=OutputConfig(
type=graphrag_config_defaults.update_index_output.type,
default=StorageConfig(
base_dir=graphrag_config_defaults.update_index_output.base_dir,
),
)
"""The output configuration for the updated index."""

def _validate_update_index_output_base_dir(self) -> None:
"""Validate the update index output base directory."""
if self.update_index_output.type == defs.OutputType.file:
if self.update_index_output.type == defs.StorageType.file:
if self.update_index_output.base_dir.strip() == "":
msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
raise ValueError(msg)
Expand Down Expand Up @@ -345,6 +354,7 @@ def _validate_model(self):
self._validate_root_dir()
self._validate_models()
self._validate_input_pattern()
self._validate_input_base_dir()
self._validate_reporting_base_dir()
self._validate_output_base_dir()
self._validate_multi_output_base_dirs()
Expand Down
27 changes: 7 additions & 20 deletions graphrag/config/models/input_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,23 @@

import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.storage_config import StorageConfig


class InputConfig(BaseModel):
"""The default configuration section for Input."""

type: InputType = Field(
description="The input type to use.",
default=graphrag_config_defaults.input.type,
storage: StorageConfig = Field(
description="The storage configuration to use for reading input documents.",
default=StorageConfig(
base_dir=graphrag_config_defaults.input.storage.base_dir,
),
)
file_type: InputFileType = Field(
description="The input file type to use.",
default=graphrag_config_defaults.input.file_type,
)
base_dir: str = Field(
description="The input base directory to use.",
default=graphrag_config_defaults.input.base_dir,
)
connection_string: str | None = Field(
description="The azure blob storage connection string to use.",
default=graphrag_config_defaults.input.connection_string,
)
storage_account_blob_url: str | None = Field(
description="The storage account blob url to use.",
default=graphrag_config_defaults.input.storage_account_blob_url,
)
container_name: str | None = Field(
description="The azure blob storage container name to use.",
default=graphrag_config_defaults.input.container_name,
)
encoding: str = Field(
description="The input file encoding to use.",
default=defs.graphrag_config_defaults.input.encoding,
Expand Down
38 changes: 0 additions & 38 deletions graphrag/config/models/output_config.py

This file was deleted.

Loading
Loading