Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250319182609055856.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Added batching logic to the prompt tuning autoselection embeddings workflow"
}
2 changes: 1 addition & 1 deletion graphrag/api/prompt_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def generate_indexing_prompts(

# if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
# to be made, or fall back to a default value in the worst case
if default_llm_settings.max_retries == -1:
if default_llm_settings.max_retries < -1:
Comment thread
nievespg1 marked this conversation as resolved.
default_llm_settings.max_retries = min(
len(doc_list), language_model_defaults.max_retries
)
Expand Down
51 changes: 26 additions & 25 deletions graphrag/prompt_tune/loader/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import numpy as np
import pandas as pd

from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input.factory import create_input
from graphrag.index.operations.embed_text.strategies.openai import (
run as run_embed_text,
)
from graphrag.index.workflows.create_base_text_units import create_base_text_units
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.base import ProgressLogger
from graphrag.prompt_tune.defaults import (
LIMIT,
Expand All @@ -21,20 +23,9 @@
from graphrag.prompt_tune.types import DocSelectionType


async def _embed_chunks(
    text_chunks: pd.DataFrame,
    embedding_llm: EmbeddingModel,
    n_subset_max: int = N_SUBSET_MAX,
) -> tuple[pd.DataFrame, np.ndarray]:
    """Embed a random subset of the text chunks.

    Args:
        text_chunks: Frame of chunks; its ``"text"`` column is what gets embedded.
        embedding_llm: Model whose ``aembed_batch`` coroutine is awaited with the
            list of sampled texts.
        n_subset_max: Upper bound on the number of rows sampled; the sample size
            is ``min(n_subset_max, len(text_chunks))``.

    Returns:
        A ``(sampled_chunks, embeddings)`` pair: the sampled rows and the
        embedding results converted to a NumPy array. Row ``i`` of the array
        belongs to row ``i`` of the returned frame, assuming ``aembed_batch``
        preserves input order (TODO confirm against the model implementation).
    """
    sample_size = min(n_subset_max, len(text_chunks))
    sampled_text_chunks = text_chunks.sample(n=sample_size)
    embeddings = await embedding_llm.aembed_batch(sampled_text_chunks["text"].tolist())
    # Return the sampled frame, not the full input: the embeddings were computed
    # for the sampled rows only (and in the sample's shuffled order), so pairing
    # them with the original frame would misalign any positional lookup that
    # maps an embedding index back to a chunk.
    return sampled_text_chunks, np.array(embeddings)


def _sample_chunks_from_embeddings(
text_chunks: pd.DataFrame,
embeddings,
embeddings: np.ndarray[float, np.dtype[np.float_]],
k: int = K,
) -> pd.DataFrame:
"""Sample text chunks from embeddings."""
Expand All @@ -60,7 +51,6 @@ async def load_docs_in_chunks(
embeddings_llm_settings = config.get_language_model_config(
config.embed_text.model_id
)

dataset = await create_input(config.input, logger, root)
chunk_config = config.chunks
chunks_df = create_base_text_units(
Expand Down Expand Up @@ -88,18 +78,29 @@ async def load_docs_in_chunks(
if k is None or k <= 0:
msg = "k must be an integer > 0"
raise ValueError(msg)
embedding_llm = ModelManager().register_embedding(
name="prompt_tuning_embeddings",
model_type=embeddings_llm_settings.type,
config=embeddings_llm_settings,
callbacks=NoopWorkflowCallbacks(),
cache=None,
)

chunks_df, embeddings = await _embed_chunks(
chunks_df, embedding_llm, n_subset_max=n_subset_max
"""Convert text chunks into dense text embeddings."""
sampled_text_chunks = chunks_df.sample(n=min(n_subset_max, len(chunks_df)))[
"text"
].tolist()

embedding_results = await run_embed_text(
sampled_text_chunks,
callbacks=NoopWorkflowCallbacks(),
cache=NoopPipelineCache(),
args={
"llm": embeddings_llm_settings.model_dump(),
"num_threads": embeddings_llm_settings.concurrent_requests,
"batch_size": config.embed_text.batch_size,
"batch_max_tokens": config.embed_text.batch_max_tokens,
},
)
embeddings = np.array(embedding_results.embeddings)
chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)

# Convert the dataset to list form, so we have a list of documents
return chunks_df["text"].tolist()
return [
# need this to prevent the str.format() function from breaking when parsing LaTeX from markdown files
i.replace("{", "{{").replace("}", "}}")
for i in chunks_df["text"]
]
Loading