diff --git a/.semversioner/next-release/patch-20250319182609055856.json b/.semversioner/next-release/patch-20250319182609055856.json
new file mode 100644
index 0000000000..87cc8a25eb
--- /dev/null
+++ b/.semversioner/next-release/patch-20250319182609055856.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Added batching logic to the prompt tuning autoselection embeddings workflow"
+}
diff --git a/graphrag/api/prompt_tune.py b/graphrag/api/prompt_tune.py
index 0c09714825..92f6ba94f3 100644
--- a/graphrag/api/prompt_tune.py
+++ b/graphrag/api/prompt_tune.py
@@ -111,7 +111,7 @@ async def generate_indexing_prompts(
 
     # if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
     # to be made or fallback to a default value in the worst case
-    if default_llm_settings.max_retries == -1:
+    if default_llm_settings.max_retries < 0:
         default_llm_settings.max_retries = min(
             len(doc_list), language_model_defaults.max_retries
         )
diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py
index 17af53ec98..fa49ebeeb9 100644
--- a/graphrag/prompt_tune/loader/input.py
+++ b/graphrag/prompt_tune/loader/input.py
@@ -6,12 +6,14 @@
 import numpy as np
 import pandas as pd
 
+from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.input.factory import create_input
+from graphrag.index.operations.embed_text.strategies.openai import (
+    run as run_embed_text,
+)
 from graphrag.index.workflows.create_base_text_units import create_base_text_units
-from graphrag.language_model.manager import ModelManager
-from graphrag.language_model.protocol.base import EmbeddingModel
 from graphrag.logger.base import ProgressLogger
 from graphrag.prompt_tune.defaults import (
     LIMIT,
@@ -21,20 +23,9 @@
 from graphrag.prompt_tune.types import DocSelectionType
 
 
-async def _embed_chunks(
-    text_chunks: pd.DataFrame,
-    embedding_llm: EmbeddingModel,
-    n_subset_max: int = N_SUBSET_MAX,
-) -> tuple[pd.DataFrame, np.ndarray]:
-    """Convert text chunks into dense text embeddings."""
-    sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
-    embeddings = await embedding_llm.aembed_batch(sampled_text_chunks["text"].tolist())
-    return text_chunks, np.array(embeddings)
-
-
 def _sample_chunks_from_embeddings(
     text_chunks: pd.DataFrame,
-    embeddings,
+    embeddings: np.ndarray[float, np.dtype[np.float_]],
     k: int = K,
 ) -> pd.DataFrame:
     """Sample text chunks from embeddings."""
@@ -60,7 +51,6 @@ async def load_docs_in_chunks(
     embeddings_llm_settings = config.get_language_model_config(
         config.embed_text.model_id
     )
-
     dataset = await create_input(config.input, logger, root)
     chunk_config = config.chunks
     chunks_df = create_base_text_units(
@@ -88,18 +78,29 @@ async def load_docs_in_chunks(
 
     if k is None or k <= 0:
         msg = "k must be an integer > 0"
         raise ValueError(msg)
-    embedding_llm = ModelManager().register_embedding(
-        name="prompt_tuning_embeddings",
-        model_type=embeddings_llm_settings.type,
-        config=embeddings_llm_settings,
-        callbacks=NoopWorkflowCallbacks(),
-        cache=None,
-    )
-    chunks_df, embeddings = await _embed_chunks(
-        chunks_df, embedding_llm, n_subset_max=n_subset_max
+    # Convert text chunks into dense text embeddings
+    sampled_text_chunks = chunks_df.sample(n=min(n_subset_max, len(chunks_df)))[
+        "text"
+    ].tolist()
+
+    embedding_results = await run_embed_text(
+        sampled_text_chunks,
+        callbacks=NoopWorkflowCallbacks(),
+        cache=NoopPipelineCache(),
+        args={
+            "llm": embeddings_llm_settings.model_dump(),
+            "num_threads": embeddings_llm_settings.concurrent_requests,
+            "batch_size": config.embed_text.batch_size,
+            "batch_max_tokens": config.embed_text.batch_max_tokens,
+        },
     )
+    embeddings = np.array(embedding_results.embeddings)
     chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)
 
     # Convert the dataset to list form, so we have a list of documents
-    return chunks_df["text"].tolist()
+    return [
+        # need this to prevent the str.format() function from breaking when parsing LaTeX from markdown files
+        i.replace("{", "{{").replace("}", "}}")
+        for i in chunks_df["text"]
+    ]
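A note on the `prompt_tune.py` hunk: `max_retries` uses `-1` as its "not configured" sentinel, so the guard must match negative values for the dynamic assignment to ever fire. Below is a minimal standalone sketch of that fallback logic; `Settings`, `DEFAULT_MAX_RETRIES`, and the sample `doc_list` are illustrative stand-ins, not graphrag's actual config types.

```python
from dataclasses import dataclass

DEFAULT_MAX_RETRIES = 10  # hypothetical fallback ceiling


@dataclass
class Settings:
    max_retries: int = -1  # -1 means "never configured"


def resolve_max_retries(settings: Settings, doc_list: list[str]) -> int:
    # negative sentinel -> derive a bound from the expected number of LLM calls
    if settings.max_retries < 0:
        settings.max_retries = min(len(doc_list), DEFAULT_MAX_RETRIES)
    return settings.max_retries


print(resolve_max_retries(Settings(), ["doc1", "doc2", "doc3"]))  # -> 3
```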
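The new `run_embed_text` call is where the batching from the changelog entry comes in: instead of one ad-hoc `aembed_batch` call, the loader reuses the indexing embedding strategy, driven by `batch_size` and `batch_max_tokens` from the `embed_text` config. The sketch below shows the general shape of token-aware batching under the assumption that `batch_size` caps the number of texts per request and `batch_max_tokens` caps their summed token count; `count_tokens` is a hypothetical stand-in for a real tokenizer (e.g. tiktoken), and this is not graphrag's actual implementation.

```python
def count_tokens(text: str) -> int:
    # rough heuristic: ~4 characters per token
    return max(1, len(text) // 4)


def make_batches(
    texts: list[str], batch_size: int = 16, batch_max_tokens: int = 8191
) -> list[list[str]]:
    batches: list[list[str]] = []
    current: list[str] = []
    current_tokens = 0
    for text in texts:
        tokens = count_tokens(text)
        # start a new batch when either limit would be exceeded
        if current and (
            len(current) >= batch_size or current_tokens + tokens > batch_max_tokens
        ):
            batches.append(current)
            current, current_tokens = [], 0
        current.append(text)
        current_tokens += tokens
    if current:
        batches.append(current)
    return batches
```

Flushing before appending (rather than after) guarantees every batch is non-empty, so a single text larger than `batch_max_tokens` still ships in its own batch instead of being dropped.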
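`_sample_chunks_from_embeddings` itself only gains a type annotation here, and its body is outside this diff. One plausible reading of the signature, consistent with the auto-selection goal, is centroid-based sampling: keep the `k` chunks whose embeddings lie closest to the mean embedding. A hedged sketch of that idea, not necessarily graphrag's implementation:

```python
import numpy as np
import pandas as pd


def sample_chunks_from_embeddings(
    text_chunks: pd.DataFrame, embeddings: np.ndarray, k: int = 15
) -> pd.DataFrame:
    center = embeddings.mean(axis=0)  # centroid of the embedding cloud
    distances = np.linalg.norm(embeddings - center, axis=1)  # L2 distance to centroid
    nearest = np.argsort(distances)[:k]  # positions of the k closest chunks
    return text_chunks.iloc[nearest]
```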
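Finally, the brace doubling in the new return value matters because the selected chunks are later spliced into prompt templates that pass through `str.format()`, where any literal `{` or `}` coming from LaTeX in markdown sources is misread as a replacement field. A self-contained demonstration (the template text is illustrative):

```python
chunk = r"LaTeX from markdown: \frac{a}{b}"
template = "-Example-\n" + chunk + "\nEntities: {entity_types}"

try:
    template.format(entity_types="person, org")
except (KeyError, ValueError) as err:
    # {a} is treated as a replacement field -> KeyError: 'a'
    print("str.format() broke:", err)

# doubling the braces makes them literal, so only {entity_types} is substituted
escaped = chunk.replace("{", "{{").replace("}", "}}")
template = "-Example-\n" + escaped + "\nEntities: {entity_types}"
print(template.format(entity_types="person, org"))
```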