Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"--root",
"${input:root_folder}",
"--method", "${input:query_method}",
"--query", "${input:query}"
"${input:query}"
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions docs/examples_notebooks/api_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"import graphrag.api as api\n",
"import pandas as pd\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
"\n",
"import graphrag.api as api"
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions docs/examples_notebooks/input_documents.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"import graphrag.api as api\n",
"import pandas as pd\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
"\n",
"import graphrag.api as api"
]
},
{
Expand Down
52 changes: 48 additions & 4 deletions packages/graphrag/graphrag/query/context_builder/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ class ContextBuilderResult:
output_tokens: int = 0


@dataclass
class LLMParameters:
    """Accumulated LLM usage for a context-building pass.

    Returned by ``BasicContextBuilder.get_llm_values`` so callers can record
    how many model calls and tokens context construction consumed.
    """

    # Number of LLM invocations made while building context.
    llm_calls: int = 0
    # Tokens sent to the model across those calls.
    prompt_tokens: int = 0
    # Tokens produced by the model across those calls.
    output_tokens: int = 0


class GlobalContextBuilder(ABC):
"""Base class for global-search context builders."""

Expand All @@ -36,6 +45,22 @@ async def build_context(
) -> ContextBuilderResult:
"""Build the context for the global search mode."""

    @abstractmethod
    async def build_context_chunks(
        self,
        query: str,
        **kwargs,
    ) -> str | list[str]:
        """Build the context chunks for the global search mode.

        Args:
            query: The user query to build context for.
            **kwargs: Implementation-specific context-building options.

        Returns:
            The formatted context text, either as a single string or as a
            list of per-window context strings.
        """

    @abstractmethod
    async def build_context_records(
        self,
        query: str,
        **kwargs,
    ) -> dict[str, pd.DataFrame]:
        """Build the context records for the global search mode.

        Args:
            query: The user query to build context for.
            **kwargs: Implementation-specific context-building options.

        Returns:
            A mapping of context name to the DataFrame of records selected
            for that context.
        """


class LocalContextBuilder(ABC):
"""Base class for local-search context builders."""
Expand All @@ -49,6 +74,22 @@ def build_context(
) -> ContextBuilderResult:
"""Build the context for the local search mode."""

    @abstractmethod
    def build_context_chunks(
        self,
        query: str,
        **kwargs,
    ) -> str:
        """Build the context chunks for the local search mode.

        Args:
            query: The user query to build context for.
            **kwargs: Implementation-specific context-building options.

        Returns:
            The formatted context text as a single string.
        """

    @abstractmethod
    def build_context_records(
        self,
        query: str,
        **kwargs,
    ) -> dict[str, pd.DataFrame]:
        """Build the context records for the local search mode.

        Args:
            query: The user query to build context for.
            **kwargs: Implementation-specific context-building options.

        Returns:
            A mapping of context name to the DataFrame of records selected
            for that context.
        """


class DRIFTContextBuilder(ABC):
"""Base class for DRIFT-search context builders."""
Expand All @@ -66,10 +107,13 @@ class BasicContextBuilder(ABC):
"""Base class for basic-search context builders."""

@abstractmethod
def build_context(
def build_context_records(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> ContextBuilderResult:
"""Build the context for the basic search mode."""
) -> dict[str, pd.DataFrame]:
"""Build the context records for the basic search mode."""

    @abstractmethod
    def get_llm_values(self) -> LLMParameters:
        """Get the LLM call values.

        Returns:
            An ``LLMParameters`` instance holding the call and token counts
            accumulated while building context.
        """
10 changes: 10 additions & 0 deletions packages/graphrag/graphrag/query/structured_search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,13 @@ async def stream_search(
yield "" # This makes it an async generator.
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)

    @abstractmethod
    async def format_records(
        self,
        records: dict[str, pd.DataFrame],
        column_delimiter: str = "|",
    ) -> str | list[str]:
        """Format context records into a string representation.

        Args:
            records: Mapping of context name to its records DataFrame.
            column_delimiter: Delimiter placed between columns in the output.

        Returns:
            The records rendered as delimited text (or a list of such texts).

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        msg = "Subclasses must implement this method"
        raise NotImplementedError(msg)
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.context_builder.builders import (
BasicContextBuilder,
ContextBuilderResult,
LLMParameters,
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
Expand All @@ -39,18 +38,17 @@ def __init__(
self.text_unit_embeddings = text_unit_embeddings
self.embedding_vectorstore_key = embedding_vectorstore_key

def build_context(
def build_context_records(
self,
query: str,
conversation_history: ConversationHistory | None = None,
k: int = 10,
max_context_tokens: int = 12_000,
context_name: str = "Sources",
column_delimiter: str = "|",
text_id_col: str = "id",
text_col: str = "text",
**kwargs,
) -> ContextBuilderResult:
) -> dict[str, pd.DataFrame]:
"""Build the context for the basic search mode."""
if query != "":
related_texts = self.text_unit_embeddings.similarity_search_by_text(
Expand Down Expand Up @@ -95,11 +93,13 @@ def build_context(
drop=True
),
)
final_text = final_text_df.to_csv(
index=False, escapechar="\\", sep=column_delimiter
)

return ContextBuilderResult(
context_chunks=final_text,
context_records={context_name.lower(): final_text_df},
return {context_name.lower(): final_text_df}

def get_llm_values(self) -> LLMParameters:
"""Get the LLM call values."""
return LLMParameters(
llm_calls=0,
prompt_tokens=0,
output_tokens=0,
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
)
from graphrag.query.context_builder.builders import BasicContextBuilder
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.base import (
BaseSearch,
SearchResult,
)
from graphrag.tokenizer.tokenizer import Tokenizer

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -49,6 +52,17 @@ def __init__(
self.callbacks = callbacks or []
self.response_type = response_type

    async def format_records(
        self,
        records: dict[str, pd.DataFrame],
        column_delimiter: str = "|",
    ) -> str | list[str]:
        """Format context records into a string representation.

        Only the single-entry case is rendered; if *records* holds more than
        one entry, or its sole DataFrame is None, an empty string is returned.

        Args:
            records: Mapping of context name to its records DataFrame.
            column_delimiter: Delimiter placed between columns in the output.

        Returns:
            The records as backslash-escaped delimited text, or "".
        """
        if len(records) == 1:
            # Exactly one context table: render it as delimited text.
            _, context_records_df = next(iter(records.items()))

            if context_records_df is not None:
                return context_records_df.to_csv(
                    index=False, escapechar="\\", sep=column_delimiter
                )
        return ""

async def search(
self,
query: str,
Expand All @@ -59,22 +73,30 @@ async def search(
start_time = time.time()
search_prompt = ""
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
context_chunks: str | list[str] = ""
column_delimiter: str = "|"

context_result = self.context_builder.build_context(
context_records = self.context_builder.build_context_records(
query=query,
conversation_history=conversation_history,
**kwargs,
**self.context_builder_params,
)

llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
llm_values = self.context_builder.get_llm_values()
llm_calls["build_context"] = llm_values.llm_calls
prompt_tokens["build_context"] = llm_values.prompt_tokens
output_tokens["build_context"] = llm_values.output_tokens

logger.debug("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
try:
context_chunks = await self.format_records(
records=context_records,
column_delimiter=column_delimiter,
)

search_prompt = self.system_prompt.format(
context_data=context_result.context_chunks,
context_data=context_chunks,
response_type=self.response_type,
)
search_messages = [
Expand All @@ -96,12 +118,12 @@ async def search(
output_tokens["response"] = len(self.tokenizer.encode(response))

for callback in self.callbacks:
callback.on_context(context_result.context_records)
callback.on_context(context_records)

return SearchResult(
response=response,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
context_data=context_records,
context_text=context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
Expand All @@ -115,8 +137,8 @@ async def search(
logger.exception("Exception in _asearch")
return SearchResult(
response="",
context_data=context_result.context_records,
context_text=context_result.context_chunks,
context_data=context_records,
context_text=context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=len(self.tokenizer.encode(search_prompt)),
Expand All @@ -133,22 +155,31 @@ async def stream_search(
) -> AsyncGenerator[str, None]:
"""Build basic search context that fits a single context window and generate answer for the user query."""
start_time = time.time()
context_chunks: str | list[str] = ""
column_delimiter: str = "|"

context_result = self.context_builder.build_context(
context_records = self.context_builder.build_context_records(
query=query,
conversation_history=conversation_history,
**self.context_builder_params,
)

logger.debug("GENERATE ANSWER: %s. QUERY: %s", start_time, query)

context_chunks = await self.format_records(
records=context_records,
column_delimiter=column_delimiter,
)

search_prompt = self.system_prompt.format(
context_data=context_result.context_chunks, response_type=self.response_type
context_data=context_chunks, response_type=self.response_type
)
search_messages = [
{"role": "system", "content": search_prompt},
]

for callback in self.callbacks:
callback.on_context(context_result.context_records)
callback.on_context(context_records)

async for chunk_response in self.model.achat_stream(
prompt=query,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,17 @@ async def search(
output_tokens_categories=output_tokens,
)

    async def format_records(
        self,
        records: dict[str, pd.DataFrame],
        column_delimiter: str = "|",
    ) -> str | list[str]:
        """Format context records into a string representation.

        Only the single-entry case is rendered; if *records* holds more than
        one entry, or its sole DataFrame is None, an empty string is returned.

        Args:
            records: Mapping of context name to its records DataFrame.
            column_delimiter: Delimiter placed between columns in the output.

        Returns:
            The records as backslash-escaped delimited text, or "".
        """
        if len(records) == 1:
            # Exactly one context table: render it as delimited text.
            _, context_records_df = next(iter(records.items()))

            if context_records_df is not None:
                return context_records_df.to_csv(
                    index=False, escapechar="\\", sep=column_delimiter
                )
        return ""

async def stream_search(
self, query: str, conversation_history: ConversationHistory | None = None
) -> AsyncGenerator[str, None]:
Expand Down
Loading
Loading