From 9ac86eac750fa335424cead720952a21717432ad Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Fri, 20 Feb 2026 23:27:32 +0530 Subject: [PATCH 01/11] LLM Chain: Add foundation for chain execution with database schema --- .../versions/048_create_llm_chain_table.py | 181 +++++ backend/app/api/main.py | 2 + backend/app/api/routes/llm_chain.py | 62 ++ backend/app/crud/llm.py | 2 + backend/app/crud/llm_chain.py | 151 ++++ backend/app/models/__init__.py | 3 + backend/app/models/job.py | 1 + backend/app/models/llm/__init__.py | 9 + backend/app/models/llm/request.py | 231 ++++++ backend/app/models/llm/response.py | 39 + backend/app/services/llm/chain/__init__.py | 0 backend/app/services/llm/chain/chain.py | 221 ++++++ backend/app/services/llm/chain/executor.py | 197 +++++ backend/app/services/llm/chain/types.py | 18 + backend/app/services/llm/jobs.py | 700 ++++++++++++------ 15 files changed, 1570 insertions(+), 247 deletions(-) create mode 100644 backend/app/alembic/versions/048_create_llm_chain_table.py create mode 100644 backend/app/api/routes/llm_chain.py create mode 100644 backend/app/crud/llm_chain.py create mode 100644 backend/app/services/llm/chain/__init__.py create mode 100644 backend/app/services/llm/chain/chain.py create mode 100644 backend/app/services/llm/chain/executor.py create mode 100644 backend/app/services/llm/chain/types.py diff --git a/backend/app/alembic/versions/048_create_llm_chain_table.py b/backend/app/alembic/versions/048_create_llm_chain_table.py new file mode 100644 index 000000000..ac49eb0ec --- /dev/null +++ b/backend/app/alembic/versions/048_create_llm_chain_table.py @@ -0,0 +1,181 @@ +"""Create llm_chain table + +Revision ID: 048 +Revises: 047 +Create Date: 2026-02-20 00:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB + +revision = "048" +down_revision = "047" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # 1. 
Create llm_chain table + op.create_table( + "llm_chain", + sa.Column( + "id", + sa.Uuid(), + nullable=False, + comment="Unique identifier for the LLM chain record", + ), + sa.Column( + "job_id", + sa.Uuid(), + nullable=False, + comment="Reference to the parent job (status tracked in job table)", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project this LLM call belongs to", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization this LLM call belongs to", + ), + sa.Column( + "status", + sa.String(), + nullable=False, + server_default="pending", + comment="Chain execution status (pending, running, failed, completed)", + ), + sa.Column( + "error", + sa.Text(), + nullable=True, + comment="Error message if the chain execution failed", + ), + sa.Column( + "block_sequences", + JSONB(), + nullable=True, + comment="Ordered list of llm_call UUIDs as blocks complete", + ), + sa.Column( + "total_blocks", + sa.Integer(), + nullable=False, + comment="Total number of blocks to execute", + ), + sa.Column( + "number_of_blocks_processed", + sa.Integer(), + nullable=False, + server_default="0", + comment="Number of blocks processed so far (used for tracking progress)", + ), + sa.Column( + "input", + sa.String(), + nullable=False, + comment="First block user's input - text string, binary data, or file path for multimodal", + ), + sa.Column( + "output", + JSONB(), + nullable=True, + comment="Last block's final output (set on chain completion)", + ), + sa.Column( + "configs", + JSONB(), + nullable=True, + comment="Ordered list of block configs as submitted in the request", + ), + sa.Column( + "total_usage", + JSONB(), + nullable=True, + comment="Aggregated token usage: {input_tokens, output_tokens, total_tokens}", + ), + sa.Column( + "metadata", + JSONB(), + nullable=True, + comment="Future-proof extensibility catch-all", + ), + sa.Column( + "started_at", + sa.DateTime(), + nullable=True, + comment="Timestamp when chain execution started", + ), + sa.Column( + "completed_at", + sa.DateTime(), + nullable=True, + comment="Timestamp when chain execution completed", + ), + sa.Column( + "created_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the chain record was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the chain record was last updated", + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["job_id"], ["job.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + ) + + op.create_index( + "idx_llm_chain_job_id", + "llm_chain", + ["job_id"], + ) + + # 2. 
Add chain_id FK column to llm_call table + op.add_column( + "llm_call", + sa.Column( + "chain_id", + sa.Uuid(), + nullable=True, + comment="Reference to the parent chain (NULL for standalone /llm/call requests)", + ), + ) + op.create_foreign_key( + "fk_llm_call_chain_id", + "llm_call", + "llm_chain", + ["chain_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_index( + "idx_llm_call_chain_id", + "llm_call", + ["chain_id"], + postgresql_where=sa.text("chain_id IS NOT NULL"), + ) + + op.execute("ALTER TYPE jobtype ADD VALUE IF NOT EXISTS 'LLM_CHAIN'") + + +def downgrade() -> None: + op.drop_index("idx_llm_call_chain_id", table_name="llm_call") + op.drop_constraint("fk_llm_call_chain_id", "llm_call", type_="foreignkey") + op.drop_column("llm_call", "chain_id") + + op.drop_index("idx_llm_chain_job_id", table_name="llm_chain") + op.drop_table("llm_chain") diff --git a/backend/app/api/main.py b/backend/app/api/main.py index ed58e57f2..5ab1cbd9e 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -10,6 +10,7 @@ login, languages, llm, + llm_chain, organization, openai_conversation, project, @@ -41,6 +42,7 @@ api_router.include_router(evaluations.router) api_router.include_router(languages.router) api_router.include_router(llm.router) +api_router.include_router(llm_chain.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) api_router.include_router(openai_conversation.router) diff --git a/backend/app/api/routes/llm_chain.py b/backend/app/api/routes/llm_chain.py new file mode 100644 index 000000000..0634c2038 --- /dev/null +++ b/backend/app/api/routes/llm_chain.py @@ -0,0 +1,62 @@ +import logging + +from fastapi import APIRouter, Depends +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.models import LLMChainRequest, LLMChainResponse, Message +from app.services.llm.jobs import start_chain_job +from app.utils import APIResponse, validate_callback_url, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["LLM Chain"]) +llm_callback_router = APIRouter() + + +@llm_callback_router.post( + "{$callback_url}", + name="llm_chain_callback", +) +def llm_callback_notification(body: APIResponse[LLMChainResponse]): + """ + Callback endpoint specification for LLM chain completion. + + The callback will receive: + - On success: APIResponse with success=True and data containing LLMChainResponse + - On failure: APIResponse with success=False and error message + - metadata field will always be included if provided in the request + """ + ... + + +@router.post( + "/llm/chain", + description=load_description("llm/llm_call.md"), + response_model=APIResponse[Message], + callbacks=llm_callback_router.routes, + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def llm_chain( + _current_user: AuthContextDep, _session: SessionDep, request: LLMChainRequest +): + """ + Endpoint to initiate an LLM chain as a background job. + """ + project_id = _current_user.project_.id + organization_id = _current_user.organization_.id + + if request.callback_url: + validate_callback_url(str(request.callback_url)) + + start_chain_job( + db=_session, + request=request, + project_id=project_id, + organization_id=organization_id, + ) + + return APIResponse.success_response( + data=Message( + message="Your response is being generated and will be delivered via callback." 
+ ), + ) diff --git a/backend/app/crud/llm.py b/backend/app/crud/llm.py index b5c23cd6e..32f8ca46f 100644 --- a/backend/app/crud/llm.py +++ b/backend/app/crud/llm.py @@ -53,6 +53,7 @@ def create_llm_call( *, request: LLMCallRequest, job_id: UUID, + chain_id: UUID | None = None, project_id: int, organization_id: int, resolved_config: ConfigBlob, @@ -120,6 +121,7 @@ def create_llm_call( job_id=job_id, project_id=project_id, organization_id=organization_id, + chain_id=chain_id, input=serialize_input(request.query.input), input_type=input_type, output_type=output_type, diff --git a/backend/app/crud/llm_chain.py b/backend/app/crud/llm_chain.py new file mode 100644 index 000000000..77ab70987 --- /dev/null +++ b/backend/app/crud/llm_chain.py @@ -0,0 +1,151 @@ +import logging +from typing import Any +from uuid import UUID + +from sqlmodel import Session + +from app.core.util import now +from app.models.llm.request import ChainStatus, LlmChain + +logger = logging.getLogger(__name__) + + +def create_llm_chain( + session: Session, + *, + job_id: UUID, + project_id: int, + organization_id: int, + total_blocks: int, + input: str, + configs: list[dict[str, Any]], +) -> LlmChain: + """Create a new LLM chain record. + Args: + session: Database session + job_id: Reference to the parent job + project_id: Reference to the project + organization_id: Reference to the organization + total_blocks: Total number of blocks to execute + input: Serialized input string (via serialize_input) + configs: Ordered list of block configs as submitted + + Returns: + LlmChain: The created chain record + """ + db_llm_chain = LlmChain( + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + status=ChainStatus.PENDING, + total_blocks=total_blocks, + number_of_blocks_processed=0, + input=input, + configs=configs, + block_sequences=[], + ) + + session.add(db_llm_chain) + session.commit() + session.refresh(db_llm_chain) + + logger.info( + f"[create_llm_chain] Created LLM chain id={db_llm_chain.id}, " + f"job_id={job_id}, total_blocks={total_blocks}" + ) + + return db_llm_chain + + +def update_llm_chain_status( + session: Session, + *, + chain_id: UUID, + status: ChainStatus, + output: dict[str, Any] | None = None, + total_usage: dict[str, Any] | None = None, + error: str | None = None, +) -> LlmChain: + """Update chain record status and related fields. 
+ Args: + session: Database session + chain_id: The chain record ID + status: New chain status + output: Last block's output dict (only for COMPLETED) + total_usage: Aggregated token usage across all blocks (for COMPLETED/FAILED) + error: Error message (only for FAILED) + + Returns: + LlmChain: The updated chain record + """ + db_chain = session.get(LlmChain, chain_id) + if not db_chain: + raise ValueError(f"LLM chain not found with id={chain_id}") + + db_chain.status = status + db_chain.updated_at = now() + + if status == ChainStatus.RUNNING: + db_chain.started_at = now() + + if status == ChainStatus.FAILED: + db_chain.error = error + db_chain.total_usage = total_usage + db_chain.completed_at = now() + + if status == ChainStatus.COMPLETED: + db_chain.output = output + db_chain.total_usage = total_usage + db_chain.completed_at = now() + + session.add(db_chain) + session.commit() + session.refresh(db_chain) + + logger.info( + f"[update_llm_chain_status] Chain {chain_id} → {status.value} | " + f"has_output={output is not None}, " + f"blocks={db_chain.number_of_blocks_processed}/{db_chain.total_blocks}, " + f"error={error}" + ) + return db_chain + + +def update_llm_chain_block_completed( + session: Session, + *, + chain_id: UUID, + llm_call_id: UUID, +) -> LlmChain: + """Update chain progress after a block completes. + Args: + session: Database session + chain_id: The chain record ID + llm_call_id: The llm_call record ID for the completed block + + Returns: + LlmChain: The updated chain record + """ + db_chain = session.get(LlmChain, chain_id) + if not db_chain: + raise ValueError(f"LLM chain not found with id={chain_id}") + + # Append to block_sequences + sequences = list(db_chain.block_sequences or []) + sequences.append(str(llm_call_id)) + db_chain.block_sequences = sequences + + # Increment progress + db_chain.number_of_blocks_processed = len(sequences) + db_chain.updated_at = now() + + session.add(db_chain) + session.commit() + session.refresh(db_chain) + + logger.info( + f"[update_llm_chain_block_completed] Chain {chain_id} | " + f"block={db_chain.number_of_blocks_processed}/{db_chain.total_blocks}, " + f"llm_call_id={llm_call_id}" + ) + return db_chain diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 2c28d7b4f..c76a02579 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -111,6 +111,9 @@ LLMCallRequest, LLMCallResponse, LlmCall, + LLMChainRequest, + LLMChainResponse, + LlmChain, ) from .message import Message diff --git a/backend/app/models/job.py b/backend/app/models/job.py index b6a1a5ae7..3b20249f5 100644 --- a/backend/app/models/job.py +++ b/backend/app/models/job.py @@ -17,6 +17,7 @@ class JobStatus(str, Enum): class JobType(str, Enum): RESPONSE = "RESPONSE" LLM_API = "LLM_API" + LLM_CHAIN = "LLM_CHAIN" class Job(SQLModel, table=True): diff --git a/backend/app/models/llm/__init__.py b/backend/app/models/llm/__init__.py index b183543c4..9bcf3a035 100644 --- a/backend/app/models/llm/__init__.py +++ b/backend/app/models/llm/__init__.py @@ -9,6 +9,13 @@ LlmCall, AudioContent, TextContent, + TextInput, + AudioInput, + PromptTemplate, + ChainBlock, + ChainStatus, + LLMChainRequest, + LlmChain, ) from app.models.llm.response import ( LLMCallResponse, @@ -17,4 +24,6 @@ Usage, TextOutput, AudioOutput, + LLMChainResponse, + IntermediateChainResponse, ) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index b90fb6229..d6abd7d8d 100644 --- a/backend/app/models/llm/request.py +++ 
b/backend/app/models/llm/request.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Annotated, Any, Literal, Union from uuid import UUID, uuid4 @@ -214,11 +215,21 @@ class Validator(SQLModel): validator_config_id: UUID +class PromptTemplate(SQLModel): + template: str = Field(..., description="Template string with {{input}} placeholder") + + class ConfigBlob(SQLModel): """Raw JSON blob of config.""" completion: CompletionConfig = Field(..., description="Completion configuration") + # used for llm-chain to provide prompt interpolation + prompt_template: PromptTemplate | None = Field( + default=None, + description="Prompt template with {{input}} placeholder to wrap around the user input", + ) + input_guardrails: list[Validator] | None = Field( default=None, description="Guardrails applied to validate/sanitize the input before the LLM call", @@ -384,6 +395,16 @@ class LlmCall(SQLModel, table=True): }, ) + chain_id: UUID | None = Field( + default=None, + foreign_key="llm_chain.id", + nullable=True, + ondelete="SET NULL", + sa_column_kwargs={ + "comment": "Reference to the parent chain (NULL for standalone llm_call requests)" + }, + ) + # Request fields input: str = Field( ..., @@ -496,3 +517,213 @@ class LlmCall(SQLModel, table=True): nullable=True, sa_column_kwargs={"comment": "Timestamp when the record was soft-deleted"}, ) + + +class ChainBlock(SQLModel): + """A single block in an LLM chain execution.""" + + config: LLMCallConfig = Field( + ..., description="LLM call configuration (stored id+version OR ad-hoc blob)" + ) + + include_provider_raw_response: bool = Field( + default=False, + description="Whether to include the raw LLM provider response in the output for this block", + ) + + intermediate_callback: bool = Field( + default=False, + description="Whether to send intermediate callback after this block completes", + ) + + +class LLMChainRequest(SQLModel): + """ + API request for an LLM chain execution. + + Orchestrates multiple LLM calls sequentially where each block's output + becomes the next block's input. + """ + + query: QueryParams = Field( + ..., description="Initial query input for the first block in the chain" + ) + + blocks: list[ChainBlock] = Field( + ..., min_length=1, description="Ordered list of blocks to execute sequentially" + ) + + callback_url: HttpUrl | None = Field( + default=None, description="Webhook URL for async response delivery" + ) + + request_metadata: dict[str, Any] | None = Field( + default=None, + description=( + "Client-provided metadata passed through unchanged in the response. " + "Use this to correlate responses with requests or track request state. " + "The exact dictionary provided here will be returned in the response metadata field." + ), + ) + + +class ChainStatus(str, Enum): + """Status of an LLM chain execution.""" + + PENDING = "pending" + RUNNING = "running" + FAILED = "failed" + COMPLETED = "completed" + + +class LlmChain(SQLModel, table=True): + """ + Database model for tracking LLM chain execution + + it manages and orchestrates sequential llm_call executions. 
+ """ + + __tablename__ = "llm_chain" + __table_args__ = ( + Index( + "idx_llm_chain_job_id", + "job_id", + ), + ) + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the LLM chain record"}, + ) + + job_id: UUID = Field( + foreign_key="job.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the parent job (status tracked in job table)" + }, + ) + + project_id: int = Field( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the project this LLM call belongs to" + }, + ) + + organization_id: int = Field( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the organization this LLM call belongs to" + }, + ) + + status: ChainStatus = Field( + default=ChainStatus.PENDING, + sa_column_kwargs={ + "comment": "Chain execution status (pending, running, failed, completed)" + }, + ) + + error: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Error message if the chain execution failed"}, + ) + + block_sequences: list[str] | None = Field( + default_factory=list, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Ordered list of llm_call UUIDs as blocks complete", + ), + ) + + total_blocks: int = Field( + ..., sa_column_kwargs={"comment": "Total number of blocks to execute"} + ) + + number_of_blocks_processed: int = Field( + default=0, + sa_column_kwargs={ + "comment": "Number of blocks processed so far (used for tracking progress)" + }, + ) + + # Request fields + input: str = Field( + ..., + sa_column_kwargs={ + "comment": "First block user's input - text string, binary data, or file path for multimodal" + }, + ) + + output: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Last block's final output (set on chain completion)", + ), + ) + + configs: list[dict[str, Any]] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Ordered list of block configs as submitted in the request", + ), + ) + + total_usage: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Aggregated token usage: {input_tokens, output_tokens, total_tokens}", + ), + ) + + metadata_: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + "metadata", + JSONB, + nullable=True, + comment="Future-proof extensibility catch-all", + ), + ) + + started_at: datetime | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Timestamp when chain execution started"}, + ) + + completed_at: datetime | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Timestamp when chain execution completed"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the chain record was created"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={ + "comment": "Timestamp when the chain record was last updated" + }, + ) diff --git a/backend/app/models/llm/response.py b/backend/app/models/llm/response.py index 7b13e301c..1ae7619f6 100644 --- a/backend/app/models/llm/response.py +++ b/backend/app/models/llm/response.py @@ -62,3 +62,42 @@ class LLMCallResponse(SQLModel): default=None, description="Unmodified raw response from the LLM provider.", ) + + 
+class LLMChainResponse(SQLModel): + """Response schema for an LLM chain execution.""" + + response: LLMResponse = Field( + ..., description="LLM response from the final step of the chain execution." + ) + usage: Usage = Field( + ..., + description="Aggregate token usage and cost for the entire chain execution.", + ) + provider_raw_response: dict[str, object] | None = Field( + default=None, + description="Raw provider response from the last block (if requested)", + ) + + +class IntermediateChainResponse(SQLModel): + """ + Intermediate callback response from the intermediate blocks + from the llm chain execution. (if configured) + + Flattend structure matching LLMCallResponse keys for consistency + """ + + type: Literal["intermediate"] = "intermediate" + block_index: int = Field(..., description="Current block position") + total_blocks: int = Field(..., description="Total number of blocks in the chain") + response: LLMResponse = Field( + ..., description="LLM Response from the current block" + ) + usage: Usage = Field( + ..., description="Token usage and cost information from the current block" + ) + provider_raw_response: dict[str, object] | None = Field( + default=None, + description="Unmodified raw response from the LLM provider from the current block", + ) diff --git a/backend/app/services/llm/chain/__init__.py b/backend/app/services/llm/chain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py new file mode 100644 index 000000000..390247d8d --- /dev/null +++ b/backend/app/services/llm/chain/chain.py @@ -0,0 +1,221 @@ +import logging +from dataclasses import dataclass, field +from typing import Any +from uuid import UUID + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.llm_chain import update_llm_chain_block_completed +from app.models.llm.request import ( + LLMCallConfig, + QueryParams, + TextInput, + TextContent, + AudioInput, +) +from app.models.llm.response import ( + IntermediateChainResponse, + TextOutput, + AudioOutput, + Usage, +) +from app.services.llm.chain.types import BlockResult +from app.services.llm.jobs import execute_llm_call +from app.utils import APIResponse, send_callback + + +logger = logging.getLogger(__name__) + + +@dataclass +class ChainContext: + """Shared state passed to all blocks. Accumulates responses.""" + + job_id: UUID + chain_id: UUID + project_id: int + organization_id: int + callback_url: str + total_blocks: int + + langfuse_credentials: dict[str, Any] | None = None + request_metadata: dict | None = None + intermediate_callback_flags: list[bool] = field(default_factory=list) + aggregated_usage: Usage = field( + default_factory=lambda: Usage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + ) + ) + + def on_block_completed(self, block_index: int, result: BlockResult) -> None: + """Called after each block completes. 
Updates chain state in DB and sends intermediate callback.""" + + if result.usage: + self.aggregated_usage.input_tokens += result.usage.input_tokens + self.aggregated_usage.output_tokens += result.usage.output_tokens + self.aggregated_usage.total_tokens += result.usage.total_tokens + + if result.success and result.llm_call_id: + with Session(engine) as session: + update_llm_chain_block_completed( + session, + chain_id=self.chain_id, + llm_call_id=result.llm_call_id, + ) + + if ( + block_index < len(self.intermediate_callback_flags) + and self.intermediate_callback_flags[block_index] + and self.callback_url + ): + self._send_intermediate_callback(block_index, result) + + def _send_intermediate_callback( + self, block_index: int, result: BlockResult + ) -> None: + """Send intermediate callback for a completed block.""" + try: + intermediate = IntermediateChainResponse( + block_index=block_index + 1, + total_blocks=self.total_blocks, + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_data = APIResponse.success_response( + data=intermediate, + metadata=self.request_metadata, + ) + send_callback( + callback_url=self.callback_url, + data=callback_data.model_dump(), + ) + logger.info( + f"[ChainContext] Sent intermediate callback | " + f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" + ) + except Exception as e: + logger.warning( + f"[ChainContext] Failed to send intermediate callback: {e} | " + f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" + ) + + +def result_to_query(result: BlockResult) -> QueryParams: + """Convert a block's output into the next block's QueryParams. + + Text output → TextInput query + Audio output → AudioInput query + """ + output = result.response.response.output + + if isinstance(output, TextOutput): + return QueryParams( + input=TextInput(content=TextContent(value=output.content.value)) + ) + elif isinstance(output, AudioOutput): + return QueryParams(input=AudioInput(content=output.content)) + else: + raise ValueError(f"Cannot chain output type: {output.type}") + + +class ChainBlock: + """A single node in the linked chain. + + Wraps execute_block() with linking capability. + Each block knows its next block and forwards output to it. + """ + + def __init__( + self, + *, + config: LLMCallConfig, + index: int, + context: ChainContext, + include_provider_raw_response: bool = False, + ): + self._config = config + self._index = index + self._context = context + self._include_provider_raw_response = include_provider_raw_response + self._next: ChainBlock | None = None + + def link(self, next_block: "ChainBlock") -> "ChainBlock": + """Link to the next block in the chain.""" + self._next = next_block + return next_block + + def execute(self, query: QueryParams) -> BlockResult: + """Execute this block, then flow to next. + + No loop. Each block calls the next via the linked reference. + Data flows through the chain like a linked list traversal. 
+ """ + logger.info( + f"[ChainBlock.execute] Executing block {self._index} | " + f"job_id={self._context.job_id}" + ) + + result = execute_llm_call( + config=self._config, + query=query, + job_id=self._context.job_id, + project_id=self._context.project_id, + organization_id=self._context.organization_id, + request_metadata=self._context.request_metadata, + langfuse_credentials=self._context.langfuse_credentials, + include_provider_raw_response=self._include_provider_raw_response, + chain_id=self._context.chain_id, + ) + + self._context.on_block_completed(self._index, result) + + if not result.success: + logger.error( + f"[ChainBlock.execute] Block {self._index} failed: {result.error} | " + f"job_id={self._context.job_id}" + ) + return result + + if self._next: + next_query = result_to_query(result) + return self._next.execute(next_query) + + logger.info( + f"[ChainBlock.execute] Block {self._index} is the last block | " + f"job_id={self._context.job_id}" + ) + return result + + +class LLMChain: + """Links ChainBlocks together into a sequential chain. + + Construction builds the linked structure. + Execution pushes input into the head — it flows through to the tail. + """ + + def __init__(self, blocks: list[ChainBlock]): + self._head: ChainBlock | None = None + self._tail: ChainBlock | None = None + self._link_blocks(blocks) + + def _link_blocks(self, blocks: list[ChainBlock]) -> None: + """Link all blocks in sequence.""" + if not blocks: + return + self._head = blocks[0] + self._tail = blocks[-1] + prev = blocks[0] + for curr in blocks[1:]: + prev.link(curr) + prev = curr + + def execute(self, query: QueryParams) -> BlockResult: + """Push input into the chain head. It flows through to the tail.""" + if not self._head: + return BlockResult(error="Chain has no blocks") + return self._head.execute(query) diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py new file mode 100644 index 000000000..78808d84c --- /dev/null +++ b/backend/app/services/llm/chain/executor.py @@ -0,0 +1,197 @@ +import logging + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.config import ConfigVersionCrud +from app.crud.jobs import JobCrud +from app.crud.llm_chain import update_llm_chain_status +from app.models import JobStatus, JobUpdate +from app.models.llm.request import ( + ChainStatus, + ConfigBlob, + LLMChainRequest, +) +from app.models.llm.response import LLMChainResponse +from app.services.llm.chain.chain import ChainContext, LLMChain +from app.services.llm.chain.types import BlockResult +from app.services.llm.jobs import ( + apply_input_guardrails, + apply_output_guardrails, + resolve_config_blob, +) +from app.utils import APIResponse, send_callback + +logger = logging.getLogger(__name__) + + +class ChainExecutor: + """Manage the lifecycle of an LLM chain execution.""" + + def __init__( + self, + *, + chain: LLMChain, + context: ChainContext, + request: LLMChainRequest, + ): + self._chain = chain + self._context = context + self._request = request + + def run(self) -> dict: + """Execute the full chain lifecycle. 
Returns serialized APIResponse.""" + try: + self._setup() + + first_config_blob, resolve_error = self._resolve_block_config_blob(0) + if resolve_error: + return self._handle_error(resolve_error) + + query, error = apply_input_guardrails( + config_blob=first_config_blob, + query=self._request.query, + job_id=self._context.job_id, + project_id=self._context.project_id, + organization_id=self._context.organization_id, + ) + if error: + return self._handle_error(error) + + result = self._chain.execute(query) + + if result.success: + last_config_blob, resolve_error = self._resolve_block_config_blob( + len(self._request.blocks) - 1 + ) + if resolve_error: + return self._handle_error(resolve_error) + + result, error = apply_output_guardrails( + config_blob=last_config_blob, + result=result, + job_id=self._context.job_id, + project_id=self._context.project_id, + organization_id=self._context.organization_id, + ) + if error: + return self._handle_error(error) + + return self._teardown(result) + + except Exception as e: + return self._handle_unexpected_error(e) + + def _resolve_block_config_blob( + self, block_index: int + ) -> tuple[ConfigBlob | None, str | None]: + """Resolve a block's config to its ConfigBlob. + + Uses is_stored_config property (same pattern as execute_job in jobs.py): + - Stored config (is_stored_config=True): fetch from DB via resolve_config_blob() + - Ad-hoc config (blob provided): return blob directly + + Returns: + (config_blob, error): ConfigBlob on success, or error string on failure + """ + block = self._request.blocks[block_index] + config = block.config + + if not config.is_stored_config: + return config.blob, None + + with Session(engine) as session: + config_crud = ConfigVersionCrud( + session=session, + project_id=self._context.project_id, + config_id=config.id, + ) + config_blob, error = resolve_config_blob(config_crud, config) + if error: + return None, error + return config_blob, None + + def _setup(self) -> None: + with Session(engine) as session: + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.PROCESSING), + ) + + update_llm_chain_status( + session=session, + chain_id=self._context.chain_id, + status=ChainStatus.RUNNING, + ) + + def _teardown(self, result: BlockResult) -> dict: + """Finalize chain record, send callback, and update job status.""" + + with Session(engine) as session: + if result.success: + final = LLMChainResponse( + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_response = APIResponse.success_response( + data=final, metadata=self._request.request_metadata + ) + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.SUCCESS), + ) + update_llm_chain_status( + session=session, + chain_id=self._context.chain_id, + status=ChainStatus.COMPLETED, + output=result.response.response.output.model_dump(), + total_usage=self._context.aggregated_usage.model_dump(), + ) + return callback_response.model_dump() + else: + return self._handle_error(result.error) + + def _handle_error(self, error: str) -> dict: + callback_response = APIResponse.failure_response( + error=error or "Unknown error occurred", + metadata=self._request.request_metadata, + ) + logger.error( + f"[ChainExecutor] Chain execution failed | " + f"chain_id={self._context.chain_id}, 
job_id={self._context.job_id}, error={error}" + ) + + with Session(engine) as session: + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + + update_llm_chain_status( + session, + chain_id=self._context.chain_id, + status=ChainStatus.FAILED, + output=None, + total_usage=self._context.aggregated_usage.model_dump(), + error=error, + ) + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.FAILED, error_message=error), + ) + return callback_response.model_dump() + + def _handle_unexpected_error(self, e: Exception) -> dict: + logger.error( + f"[ChainExecutor.run] Unexpected error: {e} | " + f"job_id={self._context.job_id}", + exc_info=True, + ) + return self._handle_error("Unexpected error occurred") diff --git a/backend/app/services/llm/chain/types.py b/backend/app/services/llm/chain/types.py new file mode 100644 index 000000000..69ab3d02f --- /dev/null +++ b/backend/app/services/llm/chain/types.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from uuid import UUID + +from app.models.llm.response import LLMCallResponse, Usage + + +@dataclass +class BlockResult: + """Result of a single block/LLM call execution.""" + + response: LLMCallResponse | None = None + llm_call_id: UUID | None = None + usage: Usage | None = None + error: str | None = None + + @property + def success(self) -> bool: + return self.error is None and self.response is not None diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index c6997a084..cd71e5bfa 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -11,23 +11,26 @@ from app.crud.config import ConfigVersionCrud from app.crud.credentials import get_provider_credential from app.crud.jobs import JobCrud -from app.crud.llm import create_llm_call, update_llm_call_response -from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, Job +from app.crud.llm import create_llm_call, serialize_input, update_llm_call_response +from app.crud.llm_chain import create_llm_chain, update_llm_chain_status +from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest from app.models.llm.request import ( + ChainStatus, ConfigBlob, - LLMCallConfig, KaapiCompletionConfig, + LLMCallConfig, + QueryParams, TextInput, ) from app.models.llm.response import TextOutput +from app.services.llm.chain.types import BlockResult from app.services.llm.guardrails import ( list_validators_config, run_guardrails_validation, ) -from app.services.llm.providers.registry import get_llm_provider +from app.services.llm.input_resolver import cleanup_temp_file, resolve_input from app.services.llm.mappers import transform_kaapi_config_to_native -from app.services.llm.input_resolver import resolve_input, cleanup_temp_file - +from app.services.llm.providers.registry import get_llm_provider from app.utils import APIResponse, send_callback logger = logging.getLogger(__name__) @@ -75,6 +78,49 @@ def start_job( return job.id +def start_chain_job( + db: Session, request: LLMChainRequest, project_id: int, organization_id: int +) -> UUID: + """Create an LLM Chain job and schedule Celery task.""" + trace_id = correlation_id.get() or "N/A" + job_crud = JobCrud(session=db) + job = job_crud.create(job_type=JobType.LLM_CHAIN, trace_id=trace_id) + + # Explicitly flush to ensure job is persisted before Celery task starts + db.flush() + db.commit() + + logger.info( + f"[start_chain_job] Created job | 
job_id={job.id}, status={job.status}, project_id={project_id}" + ) + + try: + task_id = start_high_priority_job( + function_path="app.services.llm.jobs.execute_chain_job", + project_id=project_id, + job_id=str(job.id), + trace_id=trace_id, + request_data=request.model_dump(mode="json"), + organization_id=organization_id, + ) + except Exception as e: + logger.error( + f"[start_chain_job] Error starting Celery task: {str(e)} | job_id={job.id}, project_id={project_id}", + exc_info=True, + ) + job_update = JobUpdate(status=JobStatus.FAILED, error_message=str(e)) + job_crud.update(job_id=job.id, job_update=job_update) + raise HTTPException( + status_code=500, + detail="Internal server error while executing LLM chain job", + ) + + logger.info( + f"[start_chain_job] Job scheduled for LLM chain job | job_id={job.id}, project_id={project_id}, task_id={task_id}" + ) + return job.id + + def handle_job_error( job_id: UUID, callback_url: str | None, @@ -136,226 +182,225 @@ def resolve_config_blob( return None, "Unexpected error occurred while parsing stored configuration" -def execute_job( - request_data: dict, +def apply_input_guardrails( + *, + config_blob: ConfigBlob | None, + query: QueryParams, + job_id: UUID, project_id: int, organization_id: int, - job_id: str, - task_id: str, - task_instance, -) -> dict: - """Celery task to process an LLM request asynchronously. +) -> tuple[QueryParams, str | None]: + """Apply input guardrails from a config_blob. Shared with llm-call and llm-chain.""" + if not config_blob or not config_blob.input_guardrails: + return query, None + + if not isinstance(query.input, TextInput): + logger.info( + f"[apply_input_guardrails] Skipping for non-text input. " + f"job_id={job_id}, " + f"input_type={getattr(query.input, 'type', type(query.input).__name__)}" + ) + return query, None - Returns: - dict: Serialized APIResponse[LLMCallResponse] on success, APIResponse[None] on failure - """ + input_guardrails, _ = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=config_blob.input_guardrails, + output_validator_configs=None, + ) - request = LLMCallRequest(**request_data) - job_id: UUID = UUID(job_id) + if not input_guardrails: + return query, None - config = request.config - callback_response = None - config_blob: ConfigBlob | None = None - input_guardrails: list[dict] = [] - output_guardrails: list[dict] = [] - llm_call_id: UUID | None = None # Track the LLM call record + safe = run_guardrails_validation( + query.input.content.value, + input_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) logger.info( - f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}, " + f"[apply_input_guardrails] Validation result | success={safe['success']}, job_id={job_id}" ) - try: - with Session(engine) as session: - # Update job status to PROCESSING - job_crud = JobCrud(session=session) - logger.info(f"[execute_job] Attempting to fetch job | job_id={job_id}") - job = session.get(Job, job_id) - if not job: - # Log all jobs to see what's in the database - from sqlmodel import select - - all_jobs = session.exec( - select(Job).order_by(Job.created_at.desc()).limit(5) - ).all() - logger.error( - f"[execute_job] Job not found! 
| job_id={job_id} | " - f"Recent jobs in DB: {[(j.id, j.status) for j in all_jobs]}" - ) - else: - logger.info( - f"[execute_job] Found job | job_id={job_id}, status={job.status}" - ) + if safe.get("bypassed"): + logger.info( + f"[apply_input_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" + ) + return query, None - job_crud.update( - job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) - ) + if safe["success"]: + query.input.content.value = safe["data"]["safe_text"] + return query, None - # if stored config, fetch blob from DB - if config.is_stored_config: - config_crud = ConfigVersionCrud( - session=session, project_id=project_id, config_id=config.id - ) + return query, safe["error"] - # blob is dynamic, need to resolve to ConfigBlob format - config_blob, error = resolve_config_blob(config_crud, config) - if error: - callback_response = APIResponse.failure_response( - error=error, - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) +def apply_output_guardrails( + *, + config_blob: ConfigBlob | None, + result: BlockResult, + job_id: UUID, + project_id: int, + organization_id: int, +) -> tuple[BlockResult, str | None]: + """Apply output guardrails from a config_blob. Shared by /llm/call and /llm/chain. - else: - config_blob = config.blob + Returns (modified_result, None) on success, or (result, error_string) on failure. + """ + if not config_blob or not config_blob.output_guardrails: + return result, None + + if not isinstance(result.response.response.output, TextOutput): + logger.info( + f"[apply_output_guardrails] Skipping for non-text output. " + f"job_id={job_id}, " + f"output_type={getattr(result.response.response.output, 'type', type(result.response.response.output).__name__)}" + ) + return result, None - if config_blob is not None: - if config_blob.input_guardrails or config_blob.output_guardrails: - input_guardrails, output_guardrails = list_validators_config( - organization_id=organization_id, - project_id=project_id, - input_validator_configs=config_blob.input_guardrails, - output_validator_configs=config_blob.output_guardrails, - ) + _, output_guardrails = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=None, + output_validator_configs=config_blob.output_guardrails, + ) - if input_guardrails: - if not isinstance(request.query.input, TextInput): - logger.info( - "[execute_job] Skipping input guardrails for non-text input. " - f"job_id={job_id}, input_type={getattr(request.query.input, 'type', type(request.query.input).__name__)}" - ) - else: - safe_input = run_guardrails_validation( - request.query.input.content.value, - input_guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) + if not output_guardrails: + return result, None + + output_text = result.response.response.output.content.value + safe = run_guardrails_validation( + output_text, + output_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) - logger.info( - f"[execute_job] Input guardrail validation | success={safe_input['success']}." 
- ) + logger.info( + f"[apply_output_guardrails] Validation result | success={safe['success']}, job_id={job_id}" + ) - if safe_input.get("bypassed"): - logger.info( - "[execute_job] Guardrails bypassed (service unavailable)" - ) + if safe.get("bypassed"): + logger.info( + f"[apply_output_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" + ) + return result, None - elif safe_input["success"]: - request.query.input.content.value = safe_input["data"][ - "safe_text" - ] - else: - # Update the text value with error message - request.query.input.content.value = safe_input["error"] - - callback_response = APIResponse.failure_response( - error=safe_input["error"], - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - user_sent_config_provider = "" + if safe["success"]: + result.response.response.output.content.value = safe["data"]["safe_text"] + if safe["data"].get("rephrase_needed"): + return result, result.response.response.output.content.value + return result, None - try: - # Transform Kaapi config to native config if needed (before getting provider) - completion_config = config_blob.completion + return result, safe["error"] - original_provider = ( - config_blob.completion.provider - ) # openai, google or prefixed - if isinstance(completion_config, KaapiCompletionConfig): - completion_config, warnings = transform_kaapi_config_to_native( - completion_config - ) +def execute_llm_call( + *, + config: LLMCallConfig, + query: QueryParams, + job_id: UUID, + project_id: int, + organization_id: int, + request_metadata: dict | None, + langfuse_credentials: dict | None, + include_provider_raw_response: bool = False, + chain_id: UUID | None = None, +) -> BlockResult: + """Execute a single LLM call. Shared by /llm/call and /llm/chain. + + Returns BlockResult with response + usage on success, or error on failure. 
+ """ - if request.request_metadata is None: - request.request_metadata = {} - request.request_metadata.setdefault("warnings", []).extend(warnings) - else: - pass - except Exception as e: - callback_response = APIResponse.failure_response( - error=f"Error processing configuration: {str(e)}", - metadata=request.request_metadata, + config_blob: ConfigBlob | None = None + llm_call_id: UUID | None = None + + try: + with Session(engine) as session: + if config.is_stored_config: + config_crud = ConfigVersionCrud( + session=session, project_id=project_id, config_id=config.id ) - return handle_job_error(job_id, request.callback_url, callback_response) + config_blob, error = resolve_config_blob(config_crud, config) + if error: + return BlockResult(error=error) + else: + config_blob = config.blob - # Create LLM call record before execution - try: - # Rebuild ConfigBlob with transformed native config - resolved_config_blob = ConfigBlob( - completion=completion_config, - input_guardrails=config_blob.input_guardrails, - output_guardrails=config_blob.output_guardrails, + if config_blob.prompt_template and isinstance(query.input, TextInput): + template = config_blob.prompt_template.template + interpolated = template.replace("{{input}}", query.input.content.value) + query.input.content.value = interpolated + + completion_config = config_blob.completion + original_provider = completion_config.provider + + if isinstance(completion_config, KaapiCompletionConfig): + completion_config, warnings = transform_kaapi_config_to_native( + completion_config ) + if request_metadata is None: + request_metadata = {} + request_metadata.setdefault("warnings", []).extend(warnings) + + resolved_config_blob = ConfigBlob( + completion=completion_config, + prompt_template=config_blob.prompt_template, + input_guardrails=config_blob.input_guardrails, + output_guardrails=config_blob.output_guardrails, + ) + try: + temp_request = LLMCallRequest( + query=query, + config=config, + request_metadata=request_metadata, + ) llm_call = create_llm_call( session, - request=request, + request=temp_request, job_id=job_id, project_id=project_id, organization_id=organization_id, resolved_config=resolved_config_blob, original_provider=original_provider, + chain_id=chain_id, ) llm_call_id = llm_call.id logger.info( - f"[execute_job] Created LLM call record | llm_call_id={llm_call_id}, job_id={job_id}" + f"[execute_llm_call] Created LLM call record | " + f"llm_call_id={llm_call_id}, job_id={job_id}" ) except Exception as e: logger.error( - f"[execute_job] Failed to create LLM call record: {str(e)} | job_id={job_id}", + f"[execute_llm_call] Failed to create LLM call record: {e} | job_id={job_id}", exc_info=True, ) - callback_response = APIResponse.failure_response( - error=f"Failed to create LLM call record: {str(e)}", - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) + return BlockResult(error=f"Failed to create LLM call record: {str(e)}") try: provider_instance = get_llm_provider( session=session, - provider_type=completion_config.provider, # Now always native provider type i.e openai-native, google-native regardless + provider_type=completion_config.provider, project_id=project_id, organization_id=organization_id, ) except ValueError as ve: - callback_response = APIResponse.failure_response( - error=str(ve), - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) - - langfuse_credentials = get_provider_credential( - 
session=session, - org_id=organization_id, - project_id=project_id, - provider="langfuse", - ) + return BlockResult(error=str(ve), llm_call_id=llm_call_id) - # Extract conversation_id for langfuse session grouping conversation_id = None - if request.query.conversation and request.query.conversation.id: - conversation_id = request.query.conversation.id + if query.conversation and query.conversation.id: + conversation_id = query.conversation.id - # Resolve input (handles text, audio_base64, audio_url) - resolved_input, resolve_error = resolve_input(request.query.input) + resolved_input, resolve_error = resolve_input(query.input) if resolve_error: - callback_response = APIResponse.failure_response( - error=resolve_error, - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) + return BlockResult(error=resolve_error, llm_call_id=llm_call_id) - # Apply Langfuse observability decorator to provider execute method decorated_execute = observe_llm_execution( credentials=langfuse_credentials, session_id=conversation_id, @@ -364,80 +409,16 @@ def execute_job( try: response, error = decorated_execute( completion_config=completion_config, - query=request.query, + query=query, resolved_input=resolved_input, - include_provider_raw_response=request.include_provider_raw_response, + include_provider_raw_response=include_provider_raw_response, ) finally: - # Clean up temp files for audio inputs - if resolved_input and resolved_input != request.query.input: + if resolved_input and resolved_input != query.input: cleanup_temp_file(resolved_input) if response: - if output_guardrails: - if not isinstance(response.response.output, TextOutput): - logger.info( - "[execute_job] Skipping output guardrails for non-text output. " - f"job_id={job_id}, output_type={getattr(response.response.output, 'type', type(response.response.output).__name__)}" - ) - else: - output_text = response.response.output.content.value - safe_output = run_guardrails_validation( - output_text, - output_guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) - - logger.info( - f"[execute_job] Output guardrail validation | success={safe_output['success']}." 
- ) - - if safe_output.get("bypassed"): - logger.info( - "[execute_job] Guardrails bypassed (service unavailable)" - ) - - elif safe_output["success"]: - response.response.output.content.value = safe_output["data"][ - "safe_text" - ] - - if safe_output["data"]["rephrase_needed"] == True: - callback_response = APIResponse.failure_response( - error=request.query.input, - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - - else: - response.response.output.content.value = safe_output["error"] - - callback_response = APIResponse.failure_response( - error=safe_output["error"], - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - - callback_response = APIResponse.success_response( - data=response, metadata=request.request_metadata - ) - if request.callback_url: - send_callback( - callback_url=request.callback_url, - data=callback_response.model_dump(), - ) - with Session(engine) as session: - job_crud = JobCrud(session=session) - - # Update LLM call record with response data if llm_call_id: try: update_llm_call_response( @@ -448,34 +429,154 @@ def execute_job( usage=response.usage.model_dump(), conversation_id=response.response.conversation_id, ) - logger.info( - f"[execute_job] Updated LLM call record | llm_call_id={llm_call_id}" - ) except Exception as e: logger.error( - f"[execute_job] Failed to update LLM call record: {str(e)} | llm_call_id={llm_call_id}", + f"[execute_llm_call] Failed to update LLM call record: {e} | " + f"llm_call_id={llm_call_id}", exc_info=True, ) - # Don't fail the job if updating the record fails - job_crud.update( + return BlockResult( + response=response, + llm_call_id=llm_call_id, + usage=response.usage, + ) + + return BlockResult( + error=error or "Unknown error occurred", + llm_call_id=llm_call_id, + ) + + except Exception as e: + logger.error( + f"[execute_llm_call] Unexpected error: {e} | job_id={job_id}", + exc_info=True, + ) + return BlockResult( + error="Unexpected error occurred", + llm_call_id=llm_call_id, + ) + + +def execute_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task to process an LLM request asynchronously. + + Uses centralized functions: apply_input_guardrails, apply_output_guardrails, execute_llm_call. 
+ """ + request = LLMCallRequest(**request_data) + job_id: UUID = UUID(job_id) + config = request.config + config_blob: ConfigBlob | None = None + + logger.info( + f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}" + ) + + try: + with Session(engine) as session: + job_crud = JobCrud(session=session) + job_crud.update( + job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) + ) + + if config.is_stored_config: + config_crud = ConfigVersionCrud( + session=session, project_id=project_id, config_id=config.id + ) + config_blob, error = resolve_config_blob(config_crud, config) + if error: + callback_response = APIResponse.failure_response( + error=error, + metadata=request.request_metadata, + ) + return handle_job_error( + job_id, request.callback_url, callback_response + ) + else: + config_blob = config.blob + + request.query, input_error = apply_input_guardrails( + config_blob=config_blob, + query=request.query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if input_error: + callback_response = APIResponse.failure_response( + error=input_error, + metadata=request.request_metadata, + ) + return handle_job_error(job_id, request.callback_url, callback_response) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + result = execute_llm_call( + config=request.config, + query=request.query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + request_metadata=request.request_metadata, + langfuse_credentials=langfuse_credentials, + include_provider_raw_response=request.include_provider_raw_response, + ) + + if result.success: + result, output_error = apply_output_guardrails( + config_blob=config_blob, + result=result, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if output_error: + callback_response = APIResponse.failure_response( + error=output_error, + metadata=request.request_metadata, + ) + return handle_job_error(job_id, request.callback_url, callback_response) + + callback_response = APIResponse.success_response( + data=result.response, metadata=request.request_metadata + ) + if request.callback_url: + send_callback( + callback_url=request.callback_url, + data=callback_response.model_dump(), + ) + + with Session(engine) as session: + JobCrud(session=session).update( job_id=job_id, job_update=JobUpdate(status=JobStatus.SUCCESS) ) logger.info( f"[execute_job] Successfully completed LLM job | job_id={job_id}, " - f"provider_response_id={response.response.provider_response_id}, tokens={response.usage.total_tokens}" + f"tokens={result.usage.total_tokens}" ) return callback_response.model_dump() callback_response = APIResponse.failure_response( - error=error or "Unknown error occurred", + error=result.error or "Unknown error occurred", metadata=request.request_metadata, ) return handle_job_error(job_id, request.callback_url, callback_response) except Exception as e: callback_response = APIResponse.failure_response( - error=f"Unexpected error occurred", + error="Unexpected error occurred", metadata=request.request_metadata, ) logger.error( @@ -483,3 +584,108 @@ def execute_job( exc_info=True, ) return handle_job_error(job_id, request.callback_url, callback_response) + + +def execute_chain_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task entry point for LLM chain 
execution.""" + # imports to avoid circular dependency: + from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain + from app.services.llm.chain.executor import ChainExecutor + + request = LLMChainRequest(**request_data) + job_uuid = UUID(job_id) + chain_uuid = None + + logger.info( + f"[execute_chain_job] Starting chain execution | " + f"job_id={job_uuid}, total_blocks={len(request.blocks)}" + ) + + try: + with Session(engine) as session: + chain_record = create_llm_chain( + session, + job_id=job_uuid, + project_id=project_id, + organization_id=organization_id, + total_blocks=len(request.blocks), + input=serialize_input(request.query.input), + configs=[block.model_dump(mode="json") for block in request.blocks], + ) + chain_uuid = chain_record.id + + logger.info( + f"[execute_chain_job] Created chain record | " + f"chain_id={chain_uuid}, job_id={job_uuid}" + ) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + context = ChainContext( + job_id=job_uuid, + chain_id=chain_uuid, + project_id=project_id, + organization_id=organization_id, + langfuse_credentials=langfuse_credentials, + request_metadata=request.request_metadata, + total_blocks=len(request.blocks), + callback_url=str(request.callback_url) if request.callback_url else None, + intermediate_callback_flags=[ + block.intermediate_callback for block in request.blocks + ], + ) + + blocks = [ + ChainBlock( + config=block.config, + index=i, + context=context, + include_provider_raw_response=block.include_provider_raw_response, + ) + for i, block in enumerate(request.blocks) + ] + + chain = LLMChain(blocks) + + executor = ChainExecutor(chain=chain, context=context, request=request) + return executor.run() + + except Exception as e: + logger.error( + f"[execute_chain_job] Failed: {e} | job_id={job_uuid}", + exc_info=True, + ) + + if chain_uuid: + try: + with Session(engine) as session: + update_llm_chain_status( + session, + chain_id=chain_uuid, + status=ChainStatus.FAILED, + error=str(e), + ) + except Exception: + logger.error( + f"[execute_chain_job] Failed to update chain status: {e} | " + f"chain_id={chain_uuid}", + exc_info=True, + ) + + callback_response = APIResponse.failure_response( + error="Unexpected error occurred", + metadata=request.request_metadata, + ) + return handle_job_error(job_uuid, request.callback_url, callback_response) From 6451bb03c45017aa71c78a389384adff84e427fa Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sat, 21 Feb 2026 15:54:59 +0530 Subject: [PATCH 02/11] LLM Chain: Add documentation and update endpoint description for chain execution --- backend/app/api/docs/llm/llm_chain.md | 60 +++++++++++++++++++++++++++ backend/app/api/routes/llm_chain.py | 2 +- backend/app/services/llm/jobs.py | 11 +++-- 3 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 backend/app/api/docs/llm/llm_chain.md diff --git a/backend/app/api/docs/llm/llm_chain.md b/backend/app/api/docs/llm/llm_chain.md new file mode 100644 index 000000000..d6c17893c --- /dev/null +++ b/backend/app/api/docs/llm/llm_chain.md @@ -0,0 +1,60 @@ +Execute a chain of LLM calls sequentially, where each block's output becomes the next block's input. + +This endpoint initiates an asynchronous LLM chain job. The request is queued +for processing, and results are delivered via the callback URL when complete. 
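+
+A minimal illustrative request is shown below; the stored-config UUID, model, URL, and metadata values are placeholders, and the exact `params` structure for each provider is defined in the API schema (see the parameter reference and examples below).
+
+```json
+{
+  "query": {"input": "Summarise the attached meeting notes"},
+  "blocks": [
+    {
+      "config": {"id": "<stored-config-uuid>", "version": 1},
+      "intermediate_callback": true
+    },
+    {
+      "config": {
+        "blob": {
+          "completion": {
+            "provider": "openai-native",
+            "type": "text",
+            "params": {"model": "gpt-4"}
+          },
+          "prompt_template": {"template": "Translate the following to Hindi: {{input}}"}
+        }
+      }
+    }
+  ],
+  "callback_url": "https://example.com/webhooks/llm-chain",
+  "request_metadata": {"trace_id": "example-123"}
+}
+```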
+ +### Key Parameters + +**`query`** (required) - Initial query input for the first block in the chain: +- `input` (required, string, min 1 char): User question/prompt/query +- `conversation` (optional, object): Conversation configuration + - `id` (optional, string): Existing conversation ID to continue + - `auto_create` (optional, boolean, default false): Create new conversation if no ID provided + - **Note**: Cannot specify both `id` and `auto_create=true` + + +**`blocks`** (required, array, min 1 block) - Ordered list of blocks to execute sequentially. Each block contains: + +- `config` (required) - Configuration for this block's LLM call (just choose one mode): + + - **Mode 1: Stored Configuration** + - `id` (UUID): Configuration ID + - `version` (integer >= 1): Version number + - **Both required together** + - **Note**: When using stored configuration, do not include the `blob` field in the request body + + - **Mode 2: Ad-hoc Configuration** + - `blob` (object): Complete configuration object + - `completion` (required, object): Completion configuration + - `provider` (required, string): Provider type - either `"openai"` (Kaapi abstraction) or `"openai-native"` (pass-through) + - `params` (required, object): Parameters structure depends on provider type (see schema for detailed structure) + - `prompt_template` (optional, object): Template for text interpolation + - `template` (required, string): Template string with `{{input}}` placeholder — replaced with the block's input before execution + - **Note** + - When using ad-hoc configuration, do not include `id` and `version` fields + - When using the Kaapi abstraction, parameters that are not supported by the selected provider or model are automatically suppressed. If any parameters are ignored, a list of warnings is included in the metadata.warnings. 
+ - **Recommendation**: Use stored configs (Mode 1) for production; use ad-hoc configs only for testing/validation + - **Schema**: Check the API schema or examples below for the complete parameter structure for each provider type + +- `include_provider_raw_response` (optional, boolean, default false): + - When true, includes the unmodified raw response from the LLM provider for this block + +- `intermediate_callback` (optional, boolean, default false): + - When true, sends an intermediate callback after this block completes with the block's response, usage, and position in the chain + +**`callback_url`** (optional, HTTPS URL): +- Webhook endpoint to receive the final response and intermediate callbacks +- Must be a valid HTTPS URL +- If not provided, response is only accessible through job status + +**`request_metadata`** (optional, object): +- Custom JSON metadata +- Passed through unchanged in the response + +### Note +- Input guardrails from the first block's config are applied before chain execution starts +- Output guardrails from the last block's config are applied after all blocks complete +- If any block fails, the chain stops immediately — no subsequent blocks are executed +- `warnings` list is automatically added in response metadata when using Kaapi configs if any parameters are suppressed or adjusted (e.g., temperature on reasoning models) + +--- diff --git a/backend/app/api/routes/llm_chain.py b/backend/app/api/routes/llm_chain.py index 0634c2038..92a3cdb4d 100644 --- a/backend/app/api/routes/llm_chain.py +++ b/backend/app/api/routes/llm_chain.py @@ -31,7 +31,7 @@ def llm_callback_notification(body: APIResponse[LLMChainResponse]): @router.post( "/llm/chain", - description=load_description("llm/llm_call.md"), + description=load_description("llm/llm_chain.md"), response_model=APIResponse[Message], callbacks=llm_callback_router.routes, dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index cd71e5bfa..a68ab6426 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -293,8 +293,6 @@ def apply_output_guardrails( if safe["success"]: result.response.response.output.content.value = safe["data"]["safe_text"] - if safe["data"].get("rephrase_needed"): - return result, result.response.response.output.content.value return result, None return result, safe["error"] @@ -468,7 +466,8 @@ def execute_job( ) -> dict: """Celery task to process an LLM request asynchronously. - Uses centralized functions: apply_input_guardrails, apply_output_guardrails, execute_llm_call. + Returns: + dict: Serialized APIResponse[LLMCallResponse] on success, APIResponse[None] on failure """ request = LLMCallRequest(**request_data) job_id: UUID = UUID(job_id) @@ -594,7 +593,11 @@ def execute_chain_job( task_id: str, task_instance, ) -> dict: - """Celery task entry point for LLM chain execution.""" + """Celery task to process an LLM Chain request asynchronously. 
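+
+    Creates the llm_chain record, builds ChainBlocks from the request blocks,
+    and delegates sequential execution to ChainExecutor; per-block progress is
+    recorded on the llm_chain row, and intermediate callbacks are sent for
+    blocks that request them.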
+ + Returns: + dict: Serialized APIResponse[LLMChainResponse] on success, APIResponse[None] on failure + """ # imports to avoid circular dependency: from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain from app.services.llm.chain.executor import ChainExecutor From c9f94e257552c3513d17c2f15ca880a9b4dd60e0 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sat, 21 Feb 2026 16:09:58 +0530 Subject: [PATCH 03/11] LLM Chain: Move guardrails into execute_llm_call for per-block support and eliminate code duplication --- backend/app/services/llm/chain/executor.py | 69 +-------------------- backend/app/services/llm/jobs.py | 71 +++++++--------------- 2 files changed, 24 insertions(+), 116 deletions(-) diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py index 78808d84c..25208aad6 100644 --- a/backend/app/services/llm/chain/executor.py +++ b/backend/app/services/llm/chain/executor.py @@ -3,23 +3,16 @@ from sqlmodel import Session from app.core.db import engine -from app.crud.config import ConfigVersionCrud from app.crud.jobs import JobCrud from app.crud.llm_chain import update_llm_chain_status from app.models import JobStatus, JobUpdate from app.models.llm.request import ( ChainStatus, - ConfigBlob, LLMChainRequest, ) from app.models.llm.response import LLMChainResponse from app.services.llm.chain.chain import ChainContext, LLMChain from app.services.llm.chain.types import BlockResult -from app.services.llm.jobs import ( - apply_input_guardrails, - apply_output_guardrails, - resolve_config_blob, -) from app.utils import APIResponse, send_callback logger = logging.getLogger(__name__) @@ -44,73 +37,13 @@ def run(self) -> dict: try: self._setup() - first_config_blob, resolve_error = self._resolve_block_config_blob(0) - if resolve_error: - return self._handle_error(resolve_error) - - query, error = apply_input_guardrails( - config_blob=first_config_blob, - query=self._request.query, - job_id=self._context.job_id, - project_id=self._context.project_id, - organization_id=self._context.organization_id, - ) - if error: - return self._handle_error(error) - - result = self._chain.execute(query) - - if result.success: - last_config_blob, resolve_error = self._resolve_block_config_blob( - len(self._request.blocks) - 1 - ) - if resolve_error: - return self._handle_error(resolve_error) - - result, error = apply_output_guardrails( - config_blob=last_config_blob, - result=result, - job_id=self._context.job_id, - project_id=self._context.project_id, - organization_id=self._context.organization_id, - ) - if error: - return self._handle_error(error) + result = self._chain.execute(self._request.query) return self._teardown(result) except Exception as e: return self._handle_unexpected_error(e) - def _resolve_block_config_blob( - self, block_index: int - ) -> tuple[ConfigBlob | None, str | None]: - """Resolve a block's config to its ConfigBlob. 
- - Uses is_stored_config property (same pattern as execute_job in jobs.py): - - Stored config (is_stored_config=True): fetch from DB via resolve_config_blob() - - Ad-hoc config (blob provided): return blob directly - - Returns: - (config_blob, error): ConfigBlob on success, or error string on failure - """ - block = self._request.blocks[block_index] - config = block.config - - if not config.is_stored_config: - return config.blob, None - - with Session(engine) as session: - config_crud = ConfigVersionCrud( - session=session, - project_id=self._context.project_id, - config_id=config.id, - ) - config_blob, error = resolve_config_blob(config_crud, config) - if error: - return None, error - return config_blob, None - def _setup(self) -> None: with Session(engine) as session: JobCrud(session).update( diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index a68ab6426..196bdd60b 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -335,6 +335,16 @@ def execute_llm_call( interpolated = template.replace("{{input}}", query.input.content.value) query.input.content.value = interpolated + query, input_error = apply_input_guardrails( + config_blob=config_blob, + query=query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if input_error: + return BlockResult(error=input_error) + completion_config = config_blob.completion original_provider = completion_config.provider @@ -433,13 +443,24 @@ def execute_llm_call( f"llm_call_id={llm_call_id}", exc_info=True, ) - - return BlockResult( + result = BlockResult( response=response, llm_call_id=llm_call_id, usage=response.usage, ) + result, output_error = apply_output_guardrails( + config_blob=config_blob, + result=result, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if output_error: + return BlockResult(error=output_error, llm_call_id=llm_call_id) + + return result + return BlockResult( error=error or "Unknown error occurred", llm_call_id=llm_call_id, @@ -471,8 +492,6 @@ def execute_job( """ request = LLMCallRequest(**request_data) job_id: UUID = UUID(job_id) - config = request.config - config_blob: ConfigBlob | None = None logger.info( f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}" @@ -485,36 +504,6 @@ def execute_job( job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) ) - if config.is_stored_config: - config_crud = ConfigVersionCrud( - session=session, project_id=project_id, config_id=config.id - ) - config_blob, error = resolve_config_blob(config_crud, config) - if error: - callback_response = APIResponse.failure_response( - error=error, - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - else: - config_blob = config.blob - - request.query, input_error = apply_input_guardrails( - config_blob=config_blob, - query=request.query, - job_id=job_id, - project_id=project_id, - organization_id=organization_id, - ) - if input_error: - callback_response = APIResponse.failure_response( - error=input_error, - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) - langfuse_credentials = get_provider_credential( session=session, org_id=organization_id, @@ -534,20 +523,6 @@ def execute_job( ) if result.success: - result, output_error = apply_output_guardrails( - config_blob=config_blob, - result=result, - job_id=job_id, - project_id=project_id, - 
organization_id=organization_id, - ) - if output_error: - callback_response = APIResponse.failure_response( - error=output_error, - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) - callback_response = APIResponse.success_response( data=result.response, metadata=request.request_metadata ) From baaac95920072e6b1b8a761827104705510ce034 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sun, 1 Mar 2026 20:29:49 +0530 Subject: [PATCH 04/11] prettify format --- backend/app/services/llm/jobs.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 6888fe99a..4142f690d 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -493,10 +493,7 @@ def execute_llm_call( include_provider_raw_response=include_provider_raw_response, ) except ValueError as ve: - return BlockResult( - error=str(ve), - llm_call_id=llm_call_id - ) + return BlockResult(error=str(ve), llm_call_id=llm_call_id) if response: with Session(engine) as session: From 5177bfb6b0100ee4638ea99e152a082093c93fe4 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:24:05 +0530 Subject: [PATCH 05/11] refactor: update STTLLMParams to allow optional instructions and improve callback logic in ChainContext --- backend/app/models/llm/request.py | 2 +- backend/app/services/llm/chain/chain.py | 1 + backend/app/services/llm/jobs.py | 7 ++----- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 58c4529a6..ed7f08b55 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -44,7 +44,7 @@ class TextLLMParams(SQLModel): class STTLLMParams(SQLModel): model: str - instructions: str + instructions: str | None = None input_language: str | None = None output_language: str | None = None response_format: Literal["text"] | None = Field( diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py index 390247d8d..98457fff6 100644 --- a/backend/app/services/llm/chain/chain.py +++ b/backend/app/services/llm/chain/chain.py @@ -70,6 +70,7 @@ def on_block_completed(self, block_index: int, result: BlockResult) -> None: block_index < len(self.intermediate_callback_flags) and self.intermediate_callback_flags[block_index] and self.callback_url + and block_index < self.total_blocks - 1 ): self._send_intermediate_callback(block_index, result) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 4142f690d..3efdbbc27 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -474,10 +474,7 @@ def execute_llm_call( if query.conversation and query.conversation.id: conversation_id = query.conversation.id - resolved_input, resolve_error = resolve_input(query.input) - if resolve_error: - return BlockResult(error=resolve_error, llm_call_id=llm_call_id) - + # Apply Langfuse observability decorator to provider execute method decorated_execute = observe_llm_execution( credentials=langfuse_credentials, session_id=conversation_id, @@ -485,7 +482,7 @@ def execute_llm_call( # Resolve input and execute LLM (context manager handles cleanup) try: - with resolved_input_context(query) as resolved_input: + with resolved_input_context(query.input) as resolved_input: response, error 
= decorated_execute( completion_config=completion_config, query=query, From 2fb81b1b49d84941d3d4d2b94e4c4abfb2eb2308 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:36:32 +0530 Subject: [PATCH 06/11] feat: add metadata to BlockResult and update job execution to use result metadata --- backend/app/services/llm/chain/types.py | 1 + backend/app/services/llm/jobs.py | 3 ++- backend/app/tests/services/llm/test_jobs.py | 27 ++++++++++++--------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/backend/app/services/llm/chain/types.py b/backend/app/services/llm/chain/types.py index 69ab3d02f..7fa0f39d8 100644 --- a/backend/app/services/llm/chain/types.py +++ b/backend/app/services/llm/chain/types.py @@ -12,6 +12,7 @@ class BlockResult: llm_call_id: UUID | None = None usage: Usage | None = None error: str | None = None + metadata: dict | None = None @property def success(self) -> bool: diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 3efdbbc27..62d4325cc 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -514,6 +514,7 @@ def execute_llm_call( response=response, llm_call_id=llm_call_id, usage=response.usage, + metadata=request_metadata, ) result, output_error = apply_output_guardrails( @@ -592,7 +593,7 @@ def execute_job( if result.success: callback_response = APIResponse.success_response( - data=result.response, metadata=request.request_metadata + data=result.response, metadata=result.metadata ) if callback_url_str: send_callback( diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index 60456e00b..f448be9b2 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -367,7 +367,7 @@ def test_exception_during_execution( result = self._execute_job(job_for_execution, db, request_data) assert not result["success"] - assert "Unexpected error during LLM execution" in result["error"] + assert "Unexpected error occurred" in result["error"] def test_exception_during_provider_retrieval( self, db, job_env, job_for_execution, request_data @@ -1108,16 +1108,21 @@ def test_execute_job_fetches_validator_configs_from_blob_refs( result = self._execute_job(job_for_execution, db, request_data) assert result["success"] - mock_fetch_configs.assert_called_once() - _, kwargs = mock_fetch_configs.call_args - input_validator_configs = kwargs["input_validator_configs"] - output_validator_configs = kwargs["output_validator_configs"] - assert [v.validator_config_id for v in input_validator_configs] == [ - UUID(VALIDATOR_CONFIG_ID_1) - ] - assert [v.validator_config_id for v in output_validator_configs] == [ - UUID(VALIDATOR_CONFIG_ID_2) - ] + assert mock_fetch_configs.call_count == 2 + + # First call: input guardrails + _, input_kwargs = mock_fetch_configs.call_args_list[0] + assert [ + v.validator_config_id for v in input_kwargs["input_validator_configs"] + ] == [UUID(VALIDATOR_CONFIG_ID_1)] + assert input_kwargs["output_validator_configs"] is None + + # Second call: output guardrails + _, output_kwargs = mock_fetch_configs.call_args_list[1] + assert output_kwargs["input_validator_configs"] is None + assert [ + v.validator_config_id for v in output_kwargs["output_validator_configs"] + ] == [UUID(VALIDATOR_CONFIG_ID_2)] def test_execute_job_continues_when_no_validator_configs_resolved( self, db, job_env, job_for_execution From 
113488a6bd4739c9ecbae28e3eb41a9ce40e7dce Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:43:21 +0530 Subject: [PATCH 07/11] feat: add tests for LLM chain execution and job handling --- backend/app/tests/crud/test_llm_chain.py | 153 ++++++++ backend/app/tests/services/llm/test_chain.py | 356 ++++++++++++++++++ .../tests/services/llm/test_chain_executor.py | 215 +++++++++++ backend/app/tests/services/llm/test_jobs.py | 206 +++++++++- 4 files changed, 929 insertions(+), 1 deletion(-) create mode 100644 backend/app/tests/crud/test_llm_chain.py create mode 100644 backend/app/tests/services/llm/test_chain.py create mode 100644 backend/app/tests/services/llm/test_chain_executor.py diff --git a/backend/app/tests/crud/test_llm_chain.py b/backend/app/tests/crud/test_llm_chain.py new file mode 100644 index 000000000..84324f86c --- /dev/null +++ b/backend/app/tests/crud/test_llm_chain.py @@ -0,0 +1,153 @@ +import pytest +from uuid import uuid4 + +from sqlmodel import Session + +from app.crud import JobCrud +from app.crud.llm_chain import ( + create_llm_chain, + update_llm_chain_status, + update_llm_chain_block_completed, +) +from app.models import JobType +from app.models.llm.request import ChainStatus +from app.tests.utils.utils import get_project + + +class TestCreateLlmChain: + def test_creates_chain_record(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + total_blocks=3, + input="Test input", + configs=[{"completion": {"provider": "openai-native"}}], + ) + + assert chain.id is not None + assert chain.job_id == job.id + assert chain.project_id == project.id + assert chain.status == ChainStatus.PENDING + assert chain.total_blocks == 3 + assert chain.number_of_blocks_processed == 0 + assert chain.input == "Test input" + assert chain.block_sequences == [] + + +class TestUpdateLlmChainStatus: + @pytest.fixture + def chain(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + total_blocks=2, + input="hello", + configs=[], + ) + return chain + + def test_update_to_running(self, db: Session, chain): + updated = update_llm_chain_status( + db, chain_id=chain.id, status=ChainStatus.RUNNING + ) + + assert updated.status == ChainStatus.RUNNING + assert updated.started_at is not None + + def test_update_to_completed(self, db: Session, chain): + output = {"type": "text", "content": {"value": "result"}} + usage = {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30} + + updated = update_llm_chain_status( + db, + chain_id=chain.id, + status=ChainStatus.COMPLETED, + output=output, + total_usage=usage, + ) + + assert updated.status == ChainStatus.COMPLETED + assert updated.output == output + assert updated.total_usage == usage + assert updated.completed_at is not None + + def test_update_to_failed(self, db: Session, chain): + usage = {"input_tokens": 5, "output_tokens": 0, "total_tokens": 5} + + updated = update_llm_chain_status( + db, + chain_id=chain.id, + status=ChainStatus.FAILED, + error="Provider timeout", + total_usage=usage, + ) + + assert updated.status == ChainStatus.FAILED + assert 
updated.error == "Provider timeout" + assert updated.total_usage == usage + assert updated.completed_at is not None + + def test_raises_for_missing_chain(self, db: Session): + with pytest.raises(ValueError, match="LLM chain not found"): + update_llm_chain_status(db, chain_id=uuid4(), status=ChainStatus.RUNNING) + + +class TestUpdateLlmChainBlockCompleted: + @pytest.fixture + def chain(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + total_blocks=3, + input="hello", + configs=[], + ) + return chain + + def test_appends_llm_call_id(self, db: Session, chain): + call_id = uuid4() + + updated = update_llm_chain_block_completed( + db, chain_id=chain.id, llm_call_id=call_id + ) + + assert str(call_id) in updated.block_sequences + assert updated.number_of_blocks_processed == 1 + + def test_appends_multiple_blocks(self, db: Session, chain): + call_id_1 = uuid4() + call_id_2 = uuid4() + + update_llm_chain_block_completed(db, chain_id=chain.id, llm_call_id=call_id_1) + updated = update_llm_chain_block_completed( + db, chain_id=chain.id, llm_call_id=call_id_2 + ) + + assert len(updated.block_sequences) == 2 + assert updated.number_of_blocks_processed == 2 + + def test_raises_for_missing_chain(self, db: Session): + with pytest.raises(ValueError, match="LLM chain not found"): + update_llm_chain_block_completed(db, chain_id=uuid4(), llm_call_id=uuid4()) diff --git a/backend/app/tests/services/llm/test_chain.py b/backend/app/tests/services/llm/test_chain.py new file mode 100644 index 000000000..d93380d84 --- /dev/null +++ b/backend/app/tests/services/llm/test_chain.py @@ -0,0 +1,356 @@ +from unittest.mock import patch, MagicMock +from uuid import uuid4 + +import pytest + +from app.models.llm.request import ( + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + QueryParams, + TextInput, + TextContent, + AudioInput, +) +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + Usage, + TextOutput, + TextContent as ResponseTextContent, + AudioOutput, + AudioContent, +) +from app.services.llm.chain.chain import ( + ChainBlock, + ChainContext, + LLMChain, + result_to_query, +) +from app.services.llm.chain.types import BlockResult + + +@pytest.fixture +def context(): + return ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url="https://example.com/callback", + total_blocks=3, + intermediate_callback_flags=[True, True, False], + ) + + +@pytest.fixture +def text_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-1", + conversation_id=None, + model="gpt-4", + provider="openai", + output=TextOutput(content=ResponseTextContent(value="Hello world")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + +@pytest.fixture +def audio_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-2", + conversation_id=None, + model="gemini", + provider="google", + output=AudioOutput( + content=AudioContent( + format="base64", + value="audio-data-base64", + mime_type="audio/wav", + ) + ), + ), + usage=Usage(input_tokens=5, output_tokens=15, total_tokens=20), + provider_raw_response=None, + ) + + +def make_config(): + return LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + 
provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ) + ) + ) + + +class TestResultToQuery: + def test_text_output_to_query(self, text_response): + result = BlockResult(response=text_response, usage=text_response.usage) + + query = result_to_query(result) + + assert isinstance(query.input, TextInput) + assert query.input.content.value == "Hello world" + + def test_audio_output_to_query(self, audio_response): + result = BlockResult(response=audio_response, usage=audio_response.usage) + + query = result_to_query(result) + + assert isinstance(query.input, AudioInput) + assert query.input.content.value == "audio-data-base64" + + def test_unsupported_output_type_raises(self): + mock_response = MagicMock() + mock_response.response.output.type = "unknown" + mock_response.response.output.__class__ = type("Unknown", (), {}) + result = BlockResult(response=mock_response, usage=MagicMock()) + + with pytest.raises(ValueError, match="Cannot chain output type"): + result_to_query(result) + + +class TestChainContext: + def test_aggregates_usage(self, context): + usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30) + result = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage, error=None + ) + + with patch("app.services.llm.chain.chain.Session"): + context.on_block_completed(0, result) + + assert context.aggregated_usage.input_tokens == 10 + assert context.aggregated_usage.output_tokens == 20 + assert context.aggregated_usage.total_tokens == 30 + + def test_aggregates_usage_across_blocks(self, context): + usage1 = Usage(input_tokens=10, output_tokens=20, total_tokens=30) + usage2 = Usage(input_tokens=5, output_tokens=15, total_tokens=20) + + result1 = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage1, error=None + ) + result2 = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage2, error=None + ) + + with patch("app.services.llm.chain.chain.Session"): + context.on_block_completed(0, result1) + context.on_block_completed(1, result2) + + assert context.aggregated_usage.input_tokens == 15 + assert context.aggregated_usage.total_tokens == 50 + + def test_updates_db_on_success(self, context): + llm_call_id = uuid4() + result = BlockResult( + response=MagicMock(), llm_call_id=llm_call_id, usage=MagicMock(), error=None + ) + + with patch("app.services.llm.chain.chain.Session") as mock_session, patch( + "app.services.llm.chain.chain.update_llm_chain_block_completed" + ) as mock_update: + mock_session.return_value.__enter__.return_value = MagicMock() + context.on_block_completed(0, result) + + mock_update.assert_called_once_with( + mock_session.return_value.__enter__.return_value, + chain_id=context.chain_id, + llm_call_id=llm_call_id, + ) + + def test_sends_intermediate_callback(self, context, text_response): + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.chain.Session") as mock_session, + patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), + patch("app.services.llm.chain.chain.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + context.on_block_completed(0, result) + + mock_callback.assert_called_once() + call_kwargs = mock_callback.call_args[1] + assert call_kwargs["callback_url"] == "https://example.com/callback" + + def test_skips_intermediate_callback_for_last_block(self, context, text_response): + result = BlockResult( + 
response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.chain.Session") as mock_session, + patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), + patch("app.services.llm.chain.chain.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + # Block index 2 = last block (total_blocks=3) + context.on_block_completed(2, result) + + mock_callback.assert_not_called() + + def test_skips_intermediate_callback_when_flag_false(self, context, text_response): + context.intermediate_callback_flags = [False, True, False] + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.chain.Session") as mock_session, + patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), + patch("app.services.llm.chain.chain.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + context.on_block_completed(0, result) + + mock_callback.assert_not_called() + + def test_skips_db_update_on_error(self, context): + result = BlockResult(error="Block failed", usage=MagicMock()) + + with patch( + "app.services.llm.chain.chain.update_llm_chain_block_completed" + ) as mock_update: + context.on_block_completed(0, result) + mock_update.assert_not_called() + + def test_intermediate_callback_exception_is_swallowed(self, context, text_response): + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.chain.Session") as mock_session, + patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), + patch( + "app.services.llm.chain.chain.send_callback", + side_effect=Exception("Connection error"), + ), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + # Should not raise + context.on_block_completed(0, result) + + +class TestChainBlock: + def test_execute_single_block(self, context, text_response): + query = QueryParams(input="test input") + config = make_config() + block = ChainBlock(config=config, index=0, context=context) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = block.execute(query) + + assert result.success + mock_execute.assert_called_once() + + def test_execute_chains_to_next_block(self, context, text_response): + query = QueryParams(input="test input") + config = make_config() + block1 = ChainBlock(config=config, index=0, context=context) + block2 = ChainBlock(config=config, index=1, context=context) + block1.link(block2) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = block1.execute(query) + + assert mock_execute.call_count == 2 + + def test_execute_stops_on_failure(self, context): + query = QueryParams(input="test input") + config = make_config() + block1 = ChainBlock(config=config, index=0, context=context) + block2 = ChainBlock(config=config, index=1, context=context) + block1.link(block2) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, 
patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult(error="Provider error") + + result = block1.execute(query) + + assert not result.success + assert result.error == "Provider error" + mock_execute.assert_called_once() + + +class TestLLMChain: + def test_execute_empty_chain(self): + chain = LLMChain([]) + query = QueryParams(input="test") + + result = chain.execute(query) + + assert not result.success + assert result.error == "Chain has no blocks" + + def test_execute_single_block_chain(self, context, text_response): + config = make_config() + block = ChainBlock(config=config, index=0, context=context) + chain = LLMChain([block]) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = chain.execute(QueryParams(input="hello")) + + assert result.success + mock_execute.assert_called_once() + + def test_execute_multi_block_chain(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(3)] + chain = LLMChain(blocks) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = chain.execute(QueryParams(input="hello")) + + assert result.success + assert mock_execute.call_count == 3 diff --git a/backend/app/tests/services/llm/test_chain_executor.py b/backend/app/tests/services/llm/test_chain_executor.py new file mode 100644 index 000000000..e8fdc31a9 --- /dev/null +++ b/backend/app/tests/services/llm/test_chain_executor.py @@ -0,0 +1,215 @@ +from unittest.mock import patch, MagicMock +from uuid import uuid4 + +import pytest + +from app.models.llm.request import ( + LLMChainRequest, + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + QueryParams, + ChainStatus, +) +from app.models.llm.request import ChainBlock as ChainBlockModel +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + Usage, + TextOutput, + TextContent, +) +from app.models import JobStatus +from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain +from app.services.llm.chain.executor import ChainExecutor +from app.services.llm.chain.types import BlockResult + + +@pytest.fixture +def context(): + return ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url="https://example.com/callback", + total_blocks=1, + ) + + +@pytest.fixture +def request_obj(): + return LLMChainRequest( + query=QueryParams(input="hello"), + blocks=[ + ChainBlockModel( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ) + ) + ) + ) + ], + callback_url="https://example.com/callback", + ) + + +@pytest.fixture +def text_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-1", + conversation_id=None, + model="gpt-4", + provider="openai", + output=TextOutput(content=TextContent(value="Response text")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + +@pytest.fixture +def success_result(text_response): + return BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, 
+ ) + + +@pytest.fixture +def failure_result(): + return BlockResult(error="Provider failed") + + +class TestChainExecutor: + def _make_executor(self, context, request_obj, chain_result): + mock_chain = MagicMock(spec=LLMChain) + mock_chain.execute.return_value = chain_result + return ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + def test_run_success_with_callback(self, context, request_obj, success_result): + executor = self._make_executor(context, request_obj, success_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is True + mock_callback.assert_called_once() + # Verify chain status updated to COMPLETED + completed_call = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.COMPLETED + ] + assert len(completed_call) == 1 + + def test_run_success_without_callback(self, context, request_obj, success_result): + request_obj.callback_url = None + context.callback_url = None + executor = self._make_executor(context, request_obj, success_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is True + mock_callback.assert_not_called() + + def test_run_failure_updates_status(self, context, request_obj, failure_result): + executor = self._make_executor(context, request_obj, failure_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is False + assert result["error"] == "Provider failed" + # Verify chain status updated to FAILED + failed_call = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.FAILED + ] + assert len(failed_call) == 1 + + def test_run_failure_sends_callback(self, context, request_obj, failure_result): + executor = self._make_executor(context, request_obj, failure_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + mock_callback.assert_called_once() + + def test_run_unexpected_exception_handled(self, context, request_obj): + mock_chain = MagicMock(spec=LLMChain) + mock_chain.execute.side_effect = RuntimeError("Something broke") + executor = ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = 
MagicMock() + + result = executor.run() + + assert result["success"] is False + assert "Unexpected error occurred" in result["error"] + + def test_setup_updates_job_and_chain_status( + self, context, request_obj, success_result + ): + executor = self._make_executor(context, request_obj, success_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + patch("app.services.llm.chain.executor.JobCrud") as mock_job_crud, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + executor.run() + + # _setup should set chain to RUNNING + running_calls = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.RUNNING + ] + assert len(running_calls) == 1 diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index f448be9b2..8cef08e96 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -23,11 +23,14 @@ # KaapiLLMParams, KaapiCompletionConfig, ) -from app.models.llm.request import ConfigBlob, LLMCallConfig +from app.models.llm.request import ConfigBlob, LLMCallConfig, LLMChainRequest +from app.models.llm.request import ChainBlock as ChainBlockModel from app.services.llm.jobs import ( start_job, + start_chain_job, handle_job_error, execute_job, + execute_chain_job, resolve_config_blob, ) from app.tests.utils.utils import get_project @@ -1161,6 +1164,207 @@ def test_execute_job_continues_when_no_validator_configs_resolved( mock_guardrails.assert_not_called() +class TestStartChainJob: + """Test cases for the start_chain_job function.""" + + @pytest.fixture + def chain_request(self): + return LLMChainRequest( + query=QueryParams(input="Test query"), + blocks=[ + ChainBlockModel( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ) + ) + ) + ) + ], + ) + + def test_start_chain_job_success(self, db: Session, chain_request): + project = get_project(db) + + with ( + patch("app.services.llm.jobs.start_high_priority_job") as mock_schedule, + patch("app.services.llm.jobs.JobCrud") as mock_job_crud_class, + ): + mock_schedule.return_value = "fake-task-id" + mock_job = MagicMock() + mock_job.id = uuid4() + mock_job.job_type = JobType.LLM_CHAIN + mock_job.status = JobStatus.PENDING + mock_job_crud_class.return_value.create.return_value = mock_job + + job_id = start_chain_job( + db, chain_request, project.id, project.organization_id + ) + + assert job_id == mock_job.id + mock_schedule.assert_called_once() + _, kwargs = mock_schedule.call_args + assert kwargs["function_path"] == "app.services.llm.jobs.execute_chain_job" + + def test_start_chain_job_celery_failure(self, db: Session, chain_request): + project = get_project(db) + + with ( + patch("app.services.llm.jobs.start_high_priority_job") as mock_schedule, + patch("app.services.llm.jobs.JobCrud") as mock_job_crud_class, + ): + mock_schedule.side_effect = Exception("Celery connection failed") + mock_job = MagicMock() + mock_job.id = uuid4() + mock_job_crud_class.return_value.create.return_value = mock_job + + with pytest.raises(HTTPException) as exc_info: + start_chain_job(db, chain_request, project.id, project.organization_id) + + assert exc_info.value.status_code == 500 + assert "Internal server error while executing LLM chain 
job" in str( + exc_info.value.detail + ) + + +class TestExecuteChainJob: + """Test suite for execute_chain_job.""" + + @pytest.fixture + def chain_request_data(self): + return { + "query": {"input": "Test query"}, + "blocks": [ + { + "config": { + "blob": { + "completion": { + "provider": "openai-native", + "type": "text", + "params": {"model": "gpt-4"}, + } + } + }, + } + ], + } + + @pytest.fixture + def mock_llm_response(self): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-123", + conversation_id=None, + model="gpt-4", + provider="openai", + output=TextOutput(content=TextContent(value="Test response")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + def _execute_chain_job(self, request_data): + return execute_chain_job( + request_data=request_data, + project_id=1, + organization_id=1, + job_id=str(uuid4()), + task_id="task-123", + task_instance=None, + ) + + def test_success_flow(self, chain_request_data, mock_llm_response): + from app.services.llm.chain.types import BlockResult + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch("app.services.llm.chain.executor.Session") as mock_executor_session, + patch("app.services.llm.chain.executor.send_callback"), + patch("app.services.llm.chain.executor.update_llm_chain_status"), + patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute_llm, + patch("app.services.llm.chain.chain.Session"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + mock_executor_session.return_value.__enter__.return_value = MagicMock() + mock_executor_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = uuid4() + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + + mock_execute_llm.return_value = BlockResult( + response=mock_llm_response, + llm_call_id=uuid4(), + usage=mock_llm_response.usage, + ) + + result = self._execute_chain_job(chain_request_data) + + assert result["success"] is True + + def test_exception_during_chain_creation(self, chain_request_data): + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch( + "app.services.llm.jobs.create_llm_chain", + side_effect=Exception("DB error"), + ), + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + mock_handle_error.return_value = { + "success": False, + "error": "Unexpected error occurred", + } + + result = self._execute_chain_job(chain_request_data) + + assert result["success"] is False + + def test_chain_status_updated_to_failed_on_error(self, chain_request_data): + chain_id = uuid4() + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch( + "app.services.llm.jobs.update_llm_chain_status" + ) as mock_update_status, + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + patch( + "app.services.llm.chain.chain.LLMChain", + side_effect=Exception("Chain init error"), + ), + ): + mock_session.return_value.__enter__.return_value = 
MagicMock() + mock_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = chain_id + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + mock_handle_error.return_value = { + "success": False, + "error": "Unexpected error occurred", + } + + result = self._execute_chain_job(chain_request_data) + + mock_update_status.assert_called_once() + _, kwargs = mock_update_status.call_args + assert kwargs["chain_id"] == chain_id + assert kwargs["status"].value == "failed" + + class TestResolveConfigBlob: """Test suite for resolve_config_blob function.""" From 64214656c9c60e32de2da4252283d83c1d7d2380 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:36:37 +0530 Subject: [PATCH 08/11] fix: correct variable name from job_id to job_uuid in execute_job function --- backend/app/services/llm/jobs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index b7a2f0226..221b8b32d 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -588,7 +588,7 @@ def execute_job( result = execute_llm_call( config=request.config, query=request.query, - job_id=job_id, + job_id=job_uuid, project_id=project_id, organization_id=organization_id, request_metadata=request.request_metadata, @@ -608,7 +608,7 @@ def execute_job( with Session(engine) as session: JobCrud(session=session).update( - job_id=job_id, job_update=JobUpdate(status=JobStatus.SUCCESS) + job_id=job_uuid, job_update=JobUpdate(status=JobStatus.SUCCESS) ) logger.info( f"[execute_job] Successfully completed LLM job | job_id={job_id}, " @@ -631,7 +631,7 @@ def execute_job( f"[execute_job] Unexpected error: {str(e)} | job_id={job_uuid}, task_id={task_id}", exc_info=True, ) - return handle_job_error(job_id, request.callback_url, callback_response) + return handle_job_error(job_uuid, callback_url_str, callback_response) def execute_chain_job( From 19d6f5888c77ed1969dfe70ea98626c00fecdf71 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:18:53 +0530 Subject: [PATCH 09/11] refactor: streamline LLM chain execution and enhance callback handling --- backend/app/api/docs/llm/llm_chain.md | 2 - backend/app/models/llm/request.py | 14 +- backend/app/services/llm/chain/chain.py | 155 +++---------- backend/app/services/llm/chain/executor.py | 64 +++++- backend/app/services/llm/jobs.py | 67 +----- backend/app/tests/services/llm/test_chain.py | 217 ++++-------------- .../tests/services/llm/test_chain_executor.py | 166 ++++++++++++++ backend/app/tests/services/llm/test_jobs.py | 2 +- 8 files changed, 316 insertions(+), 371 deletions(-) diff --git a/backend/app/api/docs/llm/llm_chain.md b/backend/app/api/docs/llm/llm_chain.md index d6c17893c..1d17f24bf 100644 --- a/backend/app/api/docs/llm/llm_chain.md +++ b/backend/app/api/docs/llm/llm_chain.md @@ -52,8 +52,6 @@ for processing, and results are delivered via the callback URL when complete. 
- Passed through unchanged in the response ### Note -- Input guardrails from the first block's config are applied before chain execution starts -- Output guardrails from the last block's config are applied after all blocks complete - If any block fails, the chain stops immediately — no subsequent blocks are executed - `warnings` list is automatically added in response metadata when using Kaapi configs if any parameters are suppressed or adjusted (e.g., temperature on reasoning models) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index eaa834440..cc8b11e81 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -1,19 +1,13 @@ +from datetime import datetime from enum import Enum from typing import Annotated, Any, Literal, Union - from uuid import UUID, uuid4 -from sqlmodel import Field, SQLModel -from pydantic import Discriminator, model_validator, HttpUrl -from datetime import datetime -from app.core.util import now import sqlalchemy as sa -from typing import Annotated, Any, List, Literal, Union -from uuid import UUID, uuid4 -from pydantic import model_validator, HttpUrl -from datetime import datetime +from pydantic import HttpUrl, model_validator from sqlalchemy.dialects.postgresql import JSONB -from sqlmodel import Field, SQLModel, Index, text +from sqlmodel import Field, Index, SQLModel, text + from app.core.util import now diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py index 98457fff6..ad0503675 100644 --- a/backend/app/services/llm/chain/chain.py +++ b/backend/app/services/llm/chain/chain.py @@ -1,12 +1,8 @@ import logging from dataclasses import dataclass, field -from typing import Any +from typing import Any, Callable from uuid import UUID -from sqlmodel import Session - -from app.core.db import engine -from app.crud.llm_chain import update_llm_chain_block_completed from app.models.llm.request import ( LLMCallConfig, QueryParams, @@ -15,14 +11,12 @@ AudioInput, ) from app.models.llm.response import ( - IntermediateChainResponse, TextOutput, AudioOutput, Usage, ) from app.services.llm.chain.types import BlockResult from app.services.llm.jobs import execute_llm_call -from app.utils import APIResponse, send_callback logger = logging.getLogger(__name__) @@ -30,13 +24,13 @@ @dataclass class ChainContext: - """Shared state passed to all blocks. Accumulates responses.""" + """Shared state for chain execution.""" job_id: UUID chain_id: UUID project_id: int organization_id: int - callback_url: str + callback_url: str | None total_blocks: int langfuse_credentials: dict[str, Any] | None = None @@ -50,60 +44,6 @@ class ChainContext: ) ) - def on_block_completed(self, block_index: int, result: BlockResult) -> None: - """Called after each block completes. 
Updates chain state in DB and sends intermediate callback.""" - - if result.usage: - self.aggregated_usage.input_tokens += result.usage.input_tokens - self.aggregated_usage.output_tokens += result.usage.output_tokens - self.aggregated_usage.total_tokens += result.usage.total_tokens - - if result.success and result.llm_call_id: - with Session(engine) as session: - update_llm_chain_block_completed( - session, - chain_id=self.chain_id, - llm_call_id=result.llm_call_id, - ) - - if ( - block_index < len(self.intermediate_callback_flags) - and self.intermediate_callback_flags[block_index] - and self.callback_url - and block_index < self.total_blocks - 1 - ): - self._send_intermediate_callback(block_index, result) - - def _send_intermediate_callback( - self, block_index: int, result: BlockResult - ) -> None: - """Send intermediate callback for a completed block.""" - try: - intermediate = IntermediateChainResponse( - block_index=block_index + 1, - total_blocks=self.total_blocks, - response=result.response.response, - usage=result.usage, - provider_raw_response=result.response.provider_raw_response, - ) - callback_data = APIResponse.success_response( - data=intermediate, - metadata=self.request_metadata, - ) - send_callback( - callback_url=self.callback_url, - data=callback_data.model_dump(), - ) - logger.info( - f"[ChainContext] Sent intermediate callback | " - f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" - ) - except Exception as e: - logger.warning( - f"[ChainContext] Failed to send intermediate callback: {e} | " - f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" - ) - def result_to_query(result: BlockResult) -> QueryParams: """Convert a block's output into the next block's QueryParams. @@ -124,11 +64,7 @@ def result_to_query(result: BlockResult) -> QueryParams: class ChainBlock: - """A single node in the linked chain. - - Wraps execute_block() with linking capability. - Each block knows its next block and forwards output to it. - """ + """A single block in the chain. Only responsible for executing itself.""" def __init__( self, @@ -142,25 +78,15 @@ def __init__( self._index = index self._context = context self._include_provider_raw_response = include_provider_raw_response - self._next: ChainBlock | None = None - - def link(self, next_block: "ChainBlock") -> "ChainBlock": - """Link to the next block in the chain.""" - self._next = next_block - return next_block def execute(self, query: QueryParams) -> BlockResult: - """Execute this block, then flow to next. - - No loop. Each block calls the next via the linked reference. - Data flows through the chain like a linked list traversal. 
- """ + """Execute this block and return the result.""" logger.info( f"[ChainBlock.execute] Executing block {self._index} | " f"job_id={self._context.job_id}" ) - result = execute_llm_call( + return execute_llm_call( config=self._config, query=query, job_id=self._context.job_id, @@ -172,51 +98,40 @@ def execute(self, query: QueryParams) -> BlockResult: chain_id=self._context.chain_id, ) - self._context.on_block_completed(self._index, result) - if not result.success: - logger.error( - f"[ChainBlock.execute] Block {self._index} failed: {result.error} | " - f"job_id={self._context.job_id}" - ) - return result +class LLMChain: + """Orchestrates sequential execution of ChainBlocks.""" - if self._next: - next_query = result_to_query(result) - return self._next.execute(next_query) + def __init__(self, blocks: list[ChainBlock], context: ChainContext): + self._blocks = blocks + self._context = context - logger.info( - f"[ChainBlock.execute] Block {self._index} is the last block | " - f"job_id={self._context.job_id}" - ) - return result + def execute( + self, + query: QueryParams, + on_block_completed: Callable[[int, BlockResult], None] | None = None, + ) -> BlockResult: + """Execute blocks sequentially, passing output of each to the next.""" + if not self._blocks: + return BlockResult(error="Chain has no blocks") + current_query = query + result: BlockResult | None = None -class LLMChain: - """Links ChainBlocks together into a sequential chain. + for block in self._blocks: + result = block.execute(current_query) - Construction builds the linked structure. - Execution pushes input into the head — it flows through to the tail. - """ + if on_block_completed: + on_block_completed(block._index, result) - def __init__(self, blocks: list[ChainBlock]): - self._head: ChainBlock | None = None - self._tail: ChainBlock | None = None - self._link_blocks(blocks) - - def _link_blocks(self, blocks: list[ChainBlock]) -> None: - """Link all blocks in sequence.""" - if not blocks: - return - self._head = blocks[0] - self._tail = blocks[-1] - prev = blocks[0] - for curr in blocks[1:]: - prev.link(curr) - prev = curr + if not result.success: + logger.error( + f"[LLMChain.execute] Block {block._index} failed: {result.error} | " + f"job_id={self._context.job_id}" + ) + return result - def execute(self, query: QueryParams) -> BlockResult: - """Push input into the chain head. 
It flows through to the tail.""" - if not self._head: - return BlockResult(error="Chain has no blocks") - return self._head.execute(query) + if block is not self._blocks[-1]: + current_query = result_to_query(result) + + return result diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py index 25208aad6..27ab8de86 100644 --- a/backend/app/services/llm/chain/executor.py +++ b/backend/app/services/llm/chain/executor.py @@ -4,13 +4,13 @@ from app.core.db import engine from app.crud.jobs import JobCrud -from app.crud.llm_chain import update_llm_chain_status +from app.crud.llm_chain import update_llm_chain_block_completed, update_llm_chain_status from app.models import JobStatus, JobUpdate from app.models.llm.request import ( ChainStatus, LLMChainRequest, ) -from app.models.llm.response import LLMChainResponse +from app.models.llm.response import IntermediateChainResponse, LLMChainResponse from app.services.llm.chain.chain import ChainContext, LLMChain from app.services.llm.chain.types import BlockResult from app.utils import APIResponse, send_callback @@ -37,7 +37,10 @@ def run(self) -> dict: try: self._setup() - result = self._chain.execute(self._request.query) + result = self._chain.execute( + self._request.query, + on_block_completed=self._on_block_completed, + ) return self._teardown(result) @@ -96,7 +99,7 @@ def _handle_error(self, error: str) -> dict: metadata=self._request.request_metadata, ) logger.error( - f"[ChainExecutor] Chain execution failed | " + f"[_handle_error] Chain execution failed | " f"chain_id={self._context.chain_id}, job_id={self._context.job_id}, error={error}" ) @@ -121,6 +124,59 @@ def _handle_error(self, error: str) -> dict: ) return callback_response.model_dump() + def _on_block_completed(self, block_index: int, result: BlockResult) -> None: + """Handle side effects after each block completes.""" + if result.usage: + self._context.aggregated_usage.input_tokens += result.usage.input_tokens + self._context.aggregated_usage.output_tokens += result.usage.output_tokens + self._context.aggregated_usage.total_tokens += result.usage.total_tokens + + if result.success and result.llm_call_id: + with Session(engine) as session: + update_llm_chain_block_completed( + session, + chain_id=self._context.chain_id, + llm_call_id=result.llm_call_id, + ) + + if ( + block_index < len(self._context.intermediate_callback_flags) + and self._context.intermediate_callback_flags[block_index] + and self._request.callback_url + and block_index < self._context.total_blocks - 1 + ): + self._send_intermediate_callback(block_index, result) + + def _send_intermediate_callback( + self, block_index: int, result: BlockResult + ) -> None: + """Send intermediate callback for a completed block.""" + try: + intermediate = IntermediateChainResponse( + block_index=block_index + 1, + total_blocks=self._context.total_blocks, + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_data = APIResponse.success_response( + data=intermediate, + metadata=self._context.request_metadata, + ) + send_callback( + callback_url=str(self._request.callback_url), + data=callback_data.model_dump(), + ) + logger.info( + f"[_send_intermediate_callback] Sent intermediate callback | " + f"block={block_index + 1}/{self._context.total_blocks}, job_id={self._context.job_id}" + ) + except Exception as e: + logger.warning( + f"[_send_intermediate_callback] Failed to send intermediate callback: {e} | " + 
f"block={block_index + 1}/{self._context.total_blocks}, job_id={self._context.job_id}" + ) + def _handle_unexpected_error(self, e: Exception) -> dict: logger.error( f"[ChainExecutor.run] Unexpected error: {e} | " diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 221b8b32d..2a5f7dee2 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from typing import Any from uuid import UUID + from asgi_correlation_id import correlation_id from fastapi import HTTPException from sqlmodel import Session @@ -16,15 +17,15 @@ from app.crud.llm_chain import create_llm_chain, update_llm_chain_status from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest from app.models.llm.request import ( + AudioInput, ChainStatus, ConfigBlob, + ImageInput, KaapiCompletionConfig, LLMCallConfig, + PDFInput, QueryParams, TextInput, - AudioInput, - ImageInput, - PDFInput, ) from app.models.llm.response import TextOutput from app.services.llm.chain.types import BlockResult @@ -34,7 +35,7 @@ ) from app.services.llm.mappers import transform_kaapi_config_to_native from app.services.llm.providers.registry import get_llm_provider -from app.utils import APIResponse, send_callback, resolve_input, cleanup_temp_file +from app.utils import APIResponse, cleanup_temp_file, resolve_input, send_callback logger = logging.getLogger(__name__) @@ -172,55 +173,6 @@ def resolved_input_context( cleanup_temp_file(resolved_input) -def validate_text_with_guardrails( - text: str, - guardrails: list[dict[str, Any]], - job_id: UUID, - project_id: int, - organization_id: int, - guardrail_type: str, # "input" or "output" -) -> tuple[str | None, str | None]: - """Validate text against guardrails. 
- - Returns: - (validated_text, error_message) - - If successful: (modified_text, None) - - If failed: (None, error_message) - - If bypassed: (original_text, None) - """ - safe_result = run_guardrails_validation( - text, - guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) - - logger.info( - f"[validate_text_with_guardrails] {guardrail_type.capitalize()} guardrail validation | " - f"success={safe_result['success']}, job_id={job_id}" - ) - - if safe_result.get("bypassed"): - logger.info( - f"[validate_text_with_guardrails] Guardrails bypassed (service unavailable) | " - f"job_id={job_id}" - ) - return text, None - - if safe_result["success"]: - validated_text = safe_result["data"]["safe_text"] - - # Special case for output guardrails: check if rephrase is needed - if guardrail_type == "output" and safe_result["data"].get("rephrase_needed"): - return None, "Output requires rephrasing" - - return validated_text, None - - return None, safe_result["error"] - - def resolve_config_blob( config_crud: ConfigVersionCrud, config: LLMCallConfig ) -> tuple[ConfigBlob | None, str | None]: @@ -438,14 +390,14 @@ def execute_llm_call( ) try: - temp_request = LLMCallRequest( + llm_call_request = LLMCallRequest( query=query, config=config, request_metadata=request_metadata, ) llm_call = create_llm_call( session, - request=temp_request, + request=llm_call_request, job_id=job_id, project_id=project_id, organization_id=organization_id, @@ -653,6 +605,7 @@ def execute_chain_job( request = LLMChainRequest(**request_data) job_uuid = UUID(job_id) + callback_url_str = str(request.callback_url) if request.callback_url else None chain_uuid = None logger.info( @@ -709,7 +662,7 @@ def execute_chain_job( for i, block in enumerate(request.blocks) ] - chain = LLMChain(blocks) + chain = LLMChain(blocks, context) executor = ChainExecutor(chain=chain, context=context, request=request) return executor.run() @@ -740,4 +693,4 @@ def execute_chain_job( error="Unexpected error occurred", metadata=request.request_metadata, ) - return handle_job_error(job_uuid, request.callback_url, callback_response) + return handle_job_error(job_uuid, callback_url_str, callback_response) diff --git a/backend/app/tests/services/llm/test_chain.py b/backend/app/tests/services/llm/test_chain.py index d93380d84..5b5cfed3f 100644 --- a/backend/app/tests/services/llm/test_chain.py +++ b/backend/app/tests/services/llm/test_chain.py @@ -118,153 +118,13 @@ def test_unsupported_output_type_raises(self): result_to_query(result) -class TestChainContext: - def test_aggregates_usage(self, context): - usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30) - result = BlockResult( - response=MagicMock(), llm_call_id=uuid4(), usage=usage, error=None - ) - - with patch("app.services.llm.chain.chain.Session"): - context.on_block_completed(0, result) - - assert context.aggregated_usage.input_tokens == 10 - assert context.aggregated_usage.output_tokens == 20 - assert context.aggregated_usage.total_tokens == 30 - - def test_aggregates_usage_across_blocks(self, context): - usage1 = Usage(input_tokens=10, output_tokens=20, total_tokens=30) - usage2 = Usage(input_tokens=5, output_tokens=15, total_tokens=20) - - result1 = BlockResult( - response=MagicMock(), llm_call_id=uuid4(), usage=usage1, error=None - ) - result2 = BlockResult( - response=MagicMock(), llm_call_id=uuid4(), usage=usage2, error=None - ) - - with patch("app.services.llm.chain.chain.Session"): - context.on_block_completed(0, result1) - 
context.on_block_completed(1, result2) - - assert context.aggregated_usage.input_tokens == 15 - assert context.aggregated_usage.total_tokens == 50 - - def test_updates_db_on_success(self, context): - llm_call_id = uuid4() - result = BlockResult( - response=MagicMock(), llm_call_id=llm_call_id, usage=MagicMock(), error=None - ) - - with patch("app.services.llm.chain.chain.Session") as mock_session, patch( - "app.services.llm.chain.chain.update_llm_chain_block_completed" - ) as mock_update: - mock_session.return_value.__enter__.return_value = MagicMock() - context.on_block_completed(0, result) - - mock_update.assert_called_once_with( - mock_session.return_value.__enter__.return_value, - chain_id=context.chain_id, - llm_call_id=llm_call_id, - ) - - def test_sends_intermediate_callback(self, context, text_response): - result = BlockResult( - response=text_response, - llm_call_id=uuid4(), - usage=text_response.usage, - error=None, - ) - - with ( - patch("app.services.llm.chain.chain.Session") as mock_session, - patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), - patch("app.services.llm.chain.chain.send_callback") as mock_callback, - ): - mock_session.return_value.__enter__.return_value = MagicMock() - context.on_block_completed(0, result) - - mock_callback.assert_called_once() - call_kwargs = mock_callback.call_args[1] - assert call_kwargs["callback_url"] == "https://example.com/callback" - - def test_skips_intermediate_callback_for_last_block(self, context, text_response): - result = BlockResult( - response=text_response, - llm_call_id=uuid4(), - usage=text_response.usage, - error=None, - ) - - with ( - patch("app.services.llm.chain.chain.Session") as mock_session, - patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), - patch("app.services.llm.chain.chain.send_callback") as mock_callback, - ): - mock_session.return_value.__enter__.return_value = MagicMock() - # Block index 2 = last block (total_blocks=3) - context.on_block_completed(2, result) - - mock_callback.assert_not_called() - - def test_skips_intermediate_callback_when_flag_false(self, context, text_response): - context.intermediate_callback_flags = [False, True, False] - result = BlockResult( - response=text_response, - llm_call_id=uuid4(), - usage=text_response.usage, - error=None, - ) - - with ( - patch("app.services.llm.chain.chain.Session") as mock_session, - patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), - patch("app.services.llm.chain.chain.send_callback") as mock_callback, - ): - mock_session.return_value.__enter__.return_value = MagicMock() - context.on_block_completed(0, result) - - mock_callback.assert_not_called() - - def test_skips_db_update_on_error(self, context): - result = BlockResult(error="Block failed", usage=MagicMock()) - - with patch( - "app.services.llm.chain.chain.update_llm_chain_block_completed" - ) as mock_update: - context.on_block_completed(0, result) - mock_update.assert_not_called() - - def test_intermediate_callback_exception_is_swallowed(self, context, text_response): - result = BlockResult( - response=text_response, - llm_call_id=uuid4(), - usage=text_response.usage, - error=None, - ) - - with ( - patch("app.services.llm.chain.chain.Session") as mock_session, - patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), - patch( - "app.services.llm.chain.chain.send_callback", - side_effect=Exception("Connection error"), - ), - ): - mock_session.return_value.__enter__.return_value = MagicMock() - # Should not raise - 
context.on_block_completed(0, result) - - class TestChainBlock: def test_execute_single_block(self, context, text_response): query = QueryParams(input="test input") config = make_config() block = ChainBlock(config=config, index=0, context=context) - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: mock_execute.return_value = BlockResult( response=text_response, usage=text_response.usage ) @@ -274,37 +134,15 @@ def test_execute_single_block(self, context, text_response): assert result.success mock_execute.assert_called_once() - def test_execute_chains_to_next_block(self, context, text_response): - query = QueryParams(input="test input") - config = make_config() - block1 = ChainBlock(config=config, index=0, context=context) - block2 = ChainBlock(config=config, index=1, context=context) - block1.link(block2) - - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): - mock_execute.return_value = BlockResult( - response=text_response, usage=text_response.usage - ) - - result = block1.execute(query) - - assert mock_execute.call_count == 2 - - def test_execute_stops_on_failure(self, context): + def test_execute_returns_failure(self, context): query = QueryParams(input="test input") config = make_config() - block1 = ChainBlock(config=config, index=0, context=context) - block2 = ChainBlock(config=config, index=1, context=context) - block1.link(block2) + block = ChainBlock(config=config, index=0, context=context) - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: mock_execute.return_value = BlockResult(error="Provider error") - result = block1.execute(query) + result = block.execute(query) assert not result.success assert result.error == "Provider error" @@ -312,8 +150,8 @@ def test_execute_stops_on_failure(self, context): class TestLLMChain: - def test_execute_empty_chain(self): - chain = LLMChain([]) + def test_execute_empty_chain(self, context): + chain = LLMChain([], context) query = QueryParams(input="test") result = chain.execute(query) @@ -324,11 +162,9 @@ def test_execute_empty_chain(self): def test_execute_single_block_chain(self, context, text_response): config = make_config() block = ChainBlock(config=config, index=0, context=context) - chain = LLMChain([block]) + chain = LLMChain([block], context) - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: mock_execute.return_value = BlockResult( response=text_response, usage=text_response.usage ) @@ -341,11 +177,9 @@ def test_execute_single_block_chain(self, context, text_response): def test_execute_multi_block_chain(self, context, text_response): config = make_config() blocks = [ChainBlock(config=config, index=i, context=context) for i in range(3)] - chain = LLMChain(blocks) + chain = LLMChain(blocks, context) - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: mock_execute.return_value = BlockResult( response=text_response, usage=text_response.usage ) 
@@ -354,3 +188,32 @@ def test_execute_multi_block_chain(self, context, text_response): assert result.success assert mock_execute.call_count == 3 + + def test_execute_stops_on_failure(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(3)] + chain = LLMChain(blocks, context) + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult(error="Provider error") + + result = chain.execute(QueryParams(input="hello")) + + assert not result.success + assert result.error == "Provider error" + mock_execute.assert_called_once() + + def test_execute_calls_on_block_completed(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(2)] + chain = LLMChain(blocks, context) + callback = MagicMock() + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + chain.execute(QueryParams(input="hello"), on_block_completed=callback) + + assert callback.call_count == 2 diff --git a/backend/app/tests/services/llm/test_chain_executor.py b/backend/app/tests/services/llm/test_chain_executor.py index e8fdc31a9..6564ebafb 100644 --- a/backend/app/tests/services/llm/test_chain_executor.py +++ b/backend/app/tests/services/llm/test_chain_executor.py @@ -213,3 +213,169 @@ def test_setup_updates_job_and_chain_status( if c[1].get("status") == ChainStatus.RUNNING ] assert len(running_calls) == 1 + + +class TestOnBlockCompleted: + def _make_executor(self, context, request_obj): + mock_chain = MagicMock(spec=LLMChain) + return ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + def test_aggregates_usage(self, context, request_obj): + executor = self._make_executor(context, request_obj) + usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30) + result = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage, error=None + ) + + with patch("app.services.llm.chain.executor.Session"): + executor._on_block_completed(0, result) + + assert context.aggregated_usage.input_tokens == 10 + assert context.aggregated_usage.output_tokens == 20 + assert context.aggregated_usage.total_tokens == 30 + + def test_aggregates_usage_across_blocks(self, context, request_obj): + executor = self._make_executor(context, request_obj) + result1 = BlockResult( + response=MagicMock(), + llm_call_id=uuid4(), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + error=None, + ) + result2 = BlockResult( + response=MagicMock(), + llm_call_id=uuid4(), + usage=Usage(input_tokens=5, output_tokens=15, total_tokens=20), + error=None, + ) + + with patch("app.services.llm.chain.executor.Session"): + executor._on_block_completed(0, result1) + executor._on_block_completed(1, result2) + + assert context.aggregated_usage.input_tokens == 15 + assert context.aggregated_usage.total_tokens == 50 + + def test_updates_db_on_success(self, context, request_obj): + executor = self._make_executor(context, request_obj) + llm_call_id = uuid4() + result = BlockResult( + response=MagicMock(), llm_call_id=llm_call_id, usage=MagicMock(), error=None + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch( + "app.services.llm.chain.executor.update_llm_chain_block_completed" + ) as mock_update, + ): + mock_session.return_value.__enter__.return_value = 
MagicMock() + executor._on_block_completed(0, result) + + mock_update.assert_called_once_with( + mock_session.return_value.__enter__.return_value, + chain_id=context.chain_id, + llm_call_id=llm_call_id, + ) + + def test_skips_db_update_on_error(self, context, request_obj): + executor = self._make_executor(context, request_obj) + result = BlockResult(error="Block failed", usage=MagicMock()) + + with patch( + "app.services.llm.chain.executor.update_llm_chain_block_completed" + ) as mock_update: + executor._on_block_completed(0, result) + mock_update.assert_not_called() + + def test_sends_intermediate_callback(self, context, request_obj, text_response): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(0, result) + + mock_callback.assert_called_once() + + def test_skips_intermediate_callback_for_last_block( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(2, result) + + mock_callback.assert_not_called() + + def test_skips_intermediate_callback_when_flag_false( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [False, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(0, result) + + mock_callback.assert_not_called() + + def test_intermediate_callback_exception_is_swallowed( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch( + "app.services.llm.chain.executor.send_callback", + side_effect=Exception("Connection error"), + ), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + # Should not raise + 
executor._on_block_completed(0, result) diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index 8cef08e96..cc67a7d6e 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -1287,7 +1287,7 @@ def test_success_flow(self, chain_request_data, mock_llm_response): patch("app.services.llm.chain.executor.send_callback"), patch("app.services.llm.chain.executor.update_llm_chain_status"), patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute_llm, - patch("app.services.llm.chain.chain.Session"), + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), ): mock_session.return_value.__enter__.return_value = MagicMock() mock_session.return_value.__exit__.return_value = None From 9cc5cf8af2dd17467568f8ad5e1d74e25edee5db Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Fri, 6 Mar 2026 09:02:16 +0530 Subject: [PATCH 10/11] docs: enhance llm_chain.md with detailed input specifications and guardrails --- backend/app/api/docs/llm/llm_chain.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/app/api/docs/llm/llm_chain.md b/backend/app/api/docs/llm/llm_chain.md index 1d17f24bf..0f38cc658 100644 --- a/backend/app/api/docs/llm/llm_chain.md +++ b/backend/app/api/docs/llm/llm_chain.md @@ -6,7 +6,7 @@ for processing, and results are delivered via the callback URL when complete. ### Key Parameters **`query`** (required) - Initial query input for the first block in the chain: -- `input` (required, string, min 1 char): User question/prompt/query +- `input` (required): User question/prompt/query — accepts a plain string, a structured input object (`text`, `audio`, `image`, `pdf`), or a list of structured inputs - `conversation` (optional, object): Conversation configuration - `id` (optional, string): Existing conversation ID to continue - `auto_create` (optional, boolean, default false): Create new conversation if no ID provided @@ -26,8 +26,11 @@ for processing, and results are delivered via the callback URL when complete. - **Mode 2: Ad-hoc Configuration** - `blob` (object): Complete configuration object - `completion` (required, object): Completion configuration - - `provider` (required, string): Provider type - either `"openai"` (Kaapi abstraction) or `"openai-native"` (pass-through) - - `params` (required, object): Parameters structure depends on provider type (see schema for detailed structure) + - `provider` (required, string): Kaapi providers (`openai`, `google`, `sarvamai`) — params are validated and mapped internally. 
Native providers (`openai-native`, `google-native`, `sarvamai-native`) — params are passed through as-is + - `type` (required, string): Completion type — `"text"`, `"stt"`, or `"tts"` + - `params` (required, object): Parameters structure depends on provider and type (see schema for detailed structure) + - `input_guardrails` (optional, array): Guardrails applied to validate/sanitize input before the LLM call + - `output_guardrails` (optional, array): Guardrails applied to validate/sanitize output after the LLM call - `prompt_template` (optional, object): Template for text interpolation - `template` (required, string): Template string with `{{input}}` placeholder — replaced with the block's input before execution - **Note** From f7797d1d5785909c8876e586a59c7b17f41b1857 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:55:34 +0530 Subject: [PATCH 11/11] refactor: remove unused timestamps from LlmChain model and update related tests --- .../alembic/versions/048_create_llm_chain_table.py | 14 +------------- backend/app/crud/llm_chain.py | 5 ----- backend/app/models/llm/request.py | 14 +------------- backend/app/tests/crud/test_llm_chain.py | 3 --- 4 files changed, 2 insertions(+), 34 deletions(-) diff --git a/backend/app/alembic/versions/048_create_llm_chain_table.py b/backend/app/alembic/versions/048_create_llm_chain_table.py index ac49eb0ec..ad498d465 100644 --- a/backend/app/alembic/versions/048_create_llm_chain_table.py +++ b/backend/app/alembic/versions/048_create_llm_chain_table.py @@ -107,19 +107,7 @@ def upgrade() -> None: comment="Future-proof extensibility catch-all", ), sa.Column( - "started_at", - sa.DateTime(), - nullable=True, - comment="Timestamp when chain execution started", - ), - sa.Column( - "completed_at", - sa.DateTime(), - nullable=True, - comment="Timestamp when chain execution completed", - ), - sa.Column( - "created_at", + "inserted_at", sa.DateTime(), nullable=False, comment="Timestamp when the chain record was created", diff --git a/backend/app/crud/llm_chain.py b/backend/app/crud/llm_chain.py index 77ab70987..010d8abbd 100644 --- a/backend/app/crud/llm_chain.py +++ b/backend/app/crud/llm_chain.py @@ -85,18 +85,13 @@ def update_llm_chain_status( db_chain.status = status db_chain.updated_at = now() - if status == ChainStatus.RUNNING: - db_chain.started_at = now() - if status == ChainStatus.FAILED: db_chain.error = error db_chain.total_usage = total_usage - db_chain.completed_at = now() if status == ChainStatus.COMPLETED: db_chain.output = output db_chain.total_usage = total_usage - db_chain.completed_at = now() session.add(db_chain) session.commit() diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 1760e569f..0a8c33818 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -744,19 +744,7 @@ class LlmChain(SQLModel, table=True): ), ) - started_at: datetime | None = Field( - default=None, - nullable=True, - sa_column_kwargs={"comment": "Timestamp when chain execution started"}, - ) - - completed_at: datetime | None = Field( - default=None, - nullable=True, - sa_column_kwargs={"comment": "Timestamp when chain execution completed"}, - ) - - created_at: datetime = Field( + inserted_at: datetime = Field( default_factory=now, nullable=False, sa_column_kwargs={"comment": "Timestamp when the chain record was created"}, diff --git a/backend/app/tests/crud/test_llm_chain.py b/backend/app/tests/crud/test_llm_chain.py index 
84324f86c..dfeceeee4 100644 --- a/backend/app/tests/crud/test_llm_chain.py +++ b/backend/app/tests/crud/test_llm_chain.py @@ -67,7 +67,6 @@ def test_update_to_running(self, db: Session, chain): ) assert updated.status == ChainStatus.RUNNING - assert updated.started_at is not None def test_update_to_completed(self, db: Session, chain): output = {"type": "text", "content": {"value": "result"}} @@ -84,7 +83,6 @@ def test_update_to_completed(self, db: Session, chain): assert updated.status == ChainStatus.COMPLETED assert updated.output == output assert updated.total_usage == usage - assert updated.completed_at is not None def test_update_to_failed(self, db: Session, chain): usage = {"input_tokens": 5, "output_tokens": 0, "total_tokens": 5} @@ -100,7 +98,6 @@ def test_update_to_failed(self, db: Session, chain): assert updated.status == ChainStatus.FAILED assert updated.error == "Provider timeout" assert updated.total_usage == usage - assert updated.completed_at is not None def test_raises_for_missing_chain(self, db: Session): with pytest.raises(ValueError, match="LLM chain not found"):
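
For orientation, a minimal sketch of how the reshaped chain primitives compose after this series: ChainBlock only executes its own LLM call, LLMChain walks the blocks sequentially and feeds each block's output into the next block's query, and all side effects (usage aggregation, llm_chain updates, intermediate callbacks) hang off the on_block_completed hook that ChainExecutor supplies. This is a sketch, not part of the patch: `context` (a ChainContext) and `block_configs` (a list of LLMCallConfig objects) are assumed to have been prepared upstream, as execute_chain_job does in the diff above, and the callback body stands in for ChainExecutor._on_block_completed.

    # Sketch only: `context` and `block_configs` are assumed to be built upstream
    # (see execute_chain_job in backend/app/services/llm/jobs.py).
    from app.models.llm.request import QueryParams
    from app.services.llm.chain.chain import ChainBlock, LLMChain
    from app.services.llm.chain.types import BlockResult

    # One ChainBlock per submitted block config; each block only knows how to
    # execute itself via execute_llm_call().
    blocks = [
        ChainBlock(config=config, index=i, context=context)
        for i, config in enumerate(block_configs)
    ]
    chain = LLMChain(blocks, context)

    def on_block_completed(index: int, result: BlockResult) -> None:
        # In the service this hook is ChainExecutor._on_block_completed: it
        # aggregates token usage, records the completed block in llm_chain, and
        # may send an intermediate callback for non-final blocks.
        print(f"block {index} completed, success={result.success}")

    # Each successful block's output becomes the next block's QueryParams via
    # result_to_query(); the chain returns early on the first failed block.
    result = chain.execute(
        QueryParams(input="hello"), on_block_completed=on_block_completed
    )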