diff --git a/backend/app/alembic/versions/048_create_llm_chain_table.py b/backend/app/alembic/versions/048_create_llm_chain_table.py new file mode 100644 index 000000000..ad498d465 --- /dev/null +++ b/backend/app/alembic/versions/048_create_llm_chain_table.py @@ -0,0 +1,169 @@ +"""Create llm_chain table + +Revision ID: 048 +Revises: 047 +Create Date: 2026-02-20 00:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB + +revision = "048" +down_revision = "047" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # 1. Create llm_chain table + op.create_table( + "llm_chain", + sa.Column( + "id", + sa.Uuid(), + nullable=False, + comment="Unique identifier for the LLM chain record", + ), + sa.Column( + "job_id", + sa.Uuid(), + nullable=False, + comment="Reference to the parent job (status tracked in job table)", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project this LLM call belongs to", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization this LLM call belongs to", + ), + sa.Column( + "status", + sa.String(), + nullable=False, + server_default="pending", + comment="Chain execution status (pending, running, failed, completed)", + ), + sa.Column( + "error", + sa.Text(), + nullable=True, + comment="Error message if the chain execution failed", + ), + sa.Column( + "block_sequences", + JSONB(), + nullable=True, + comment="Ordered list of llm_call UUIDs as blocks complete", + ), + sa.Column( + "total_blocks", + sa.Integer(), + nullable=False, + comment="Total number of blocks to execute", + ), + sa.Column( + "number_of_blocks_processed", + sa.Integer(), + nullable=False, + server_default="0", + comment="Number of blocks processed so far (used for tracking progress)", + ), + sa.Column( + "input", + sa.String(), + nullable=False, + comment="First block user's input - text string, 
binary data, or file path for multimodal", + ), + sa.Column( + "output", + JSONB(), + nullable=True, + comment="Last block's final output (set on chain completion)", + ), + sa.Column( + "configs", + JSONB(), + nullable=True, + comment="Ordered list of block configs as submitted in the request", + ), + sa.Column( + "total_usage", + JSONB(), + nullable=True, + comment="Aggregated token usage: {input_tokens, output_tokens, total_tokens}", + ), + sa.Column( + "metadata", + JSONB(), + nullable=True, + comment="Future-proof extensibility catch-all", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the chain record was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the chain record was last updated", + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["job_id"], ["job.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + ) + + op.create_index( + "idx_llm_chain_job_id", + "llm_chain", + ["job_id"], + ) + + # 2. 
Add chain_id FK column to llm_call table + op.add_column( + "llm_call", + sa.Column( + "chain_id", + sa.Uuid(), + nullable=True, + comment="Reference to the parent chain (NULL for standalone /llm/call requests)", + ), + ) + op.create_foreign_key( + "fk_llm_call_chain_id", + "llm_call", + "llm_chain", + ["chain_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_index( + "idx_llm_call_chain_id", + "llm_call", + ["chain_id"], + postgresql_where=sa.text("chain_id IS NOT NULL"), + ) + + op.execute("ALTER TYPE jobtype ADD VALUE IF NOT EXISTS 'LLM_CHAIN'") + + +def downgrade() -> None: + op.drop_index("idx_llm_call_chain_id", table_name="llm_call") + op.drop_constraint("fk_llm_call_chain_id", "llm_call", type_="foreignkey") + op.drop_column("llm_call", "chain_id") + + op.drop_index("idx_llm_chain_job_id", table_name="llm_chain") + op.drop_table("llm_chain") diff --git a/backend/app/api/docs/llm/llm_chain.md b/backend/app/api/docs/llm/llm_chain.md new file mode 100644 index 000000000..0f38cc658 --- /dev/null +++ b/backend/app/api/docs/llm/llm_chain.md @@ -0,0 +1,61 @@ +Execute a chain of LLM calls sequentially, where each block's output becomes the next block's input. + +This endpoint initiates an asynchronous LLM chain job. The request is queued +for processing, and results are delivered via the callback URL when complete. 
+ +### Key Parameters + +**`query`** (required) - Initial query input for the first block in the chain: +- `input` (required): User question/prompt/query — accepts a plain string, a structured input object (`text`, `audio`, `image`, `pdf`), or a list of structured inputs +- `conversation` (optional, object): Conversation configuration + - `id` (optional, string): Existing conversation ID to continue + - `auto_create` (optional, boolean, default false): Create new conversation if no ID provided + - **Note**: Cannot specify both `id` and `auto_create=true` + + +**`blocks`** (required, array, min 1 block) - Ordered list of blocks to execute sequentially. Each block contains: + +- `config` (required) - Configuration for this block's LLM call (choose exactly one mode): + + - **Mode 1: Stored Configuration** + - `id` (UUID): Configuration ID + - `version` (integer >= 1): Version number + - **Both required together** + - **Note**: When using stored configuration, do not include the `blob` field in the request body + + - **Mode 2: Ad-hoc Configuration** + - `blob` (object): Complete configuration object + - `completion` (required, object): Completion configuration + - `provider` (required, string): Kaapi providers (`openai`, `google`, `sarvamai`) — params are validated and mapped internally. 
Native providers (`openai-native`, `google-native`, `sarvamai-native`) — params are passed through as-is + - `type` (required, string): Completion type — `"text"`, `"stt"`, or `"tts"` + - `params` (required, object): Parameters structure depends on provider and type (see schema for detailed structure) + - `input_guardrails` (optional, array): Guardrails applied to validate/sanitize input before the LLM call + - `output_guardrails` (optional, array): Guardrails applied to validate/sanitize output after the LLM call + - `prompt_template` (optional, object): Template for text interpolation + - `template` (required, string): Template string with `{{input}}` placeholder — replaced with the block's input before execution + - **Note** + - When using ad-hoc configuration, do not include `id` and `version` fields + - When using the Kaapi abstraction, parameters that are not supported by the selected provider or model are automatically suppressed. If any parameters are ignored, a list of warnings is included in `metadata.warnings`. 
+ - **Recommendation**: Use stored configs (Mode 1) for production; use ad-hoc configs only for testing/validation + - **Schema**: Check the API schema or examples below for the complete parameter structure for each provider type + +- `include_provider_raw_response` (optional, boolean, default false): + - When true, includes the unmodified raw response from the LLM provider for this block + +- `intermediate_callback` (optional, boolean, default false): + - When true, sends an intermediate callback after this block completes with the block's response, usage, and position in the chain + +**`callback_url`** (optional, HTTPS URL): +- Webhook endpoint to receive the final response and intermediate callbacks +- Must be a valid HTTPS URL +- If not provided, response is only accessible through job status + +**`request_metadata`** (optional, object): +- Custom JSON metadata +- Passed through unchanged in the response + +### Note +- If any block fails, the chain stops immediately — no subsequent blocks are executed +- `warnings` list is automatically added in response metadata when using Kaapi configs if any parameters are suppressed or adjusted (e.g., temperature on reasoning models) + +--- diff --git a/backend/app/api/main.py b/backend/app/api/main.py index ed58e57f2..5ab1cbd9e 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -10,6 +10,7 @@ login, languages, llm, + llm_chain, organization, openai_conversation, project, @@ -41,6 +42,7 @@ api_router.include_router(evaluations.router) api_router.include_router(languages.router) api_router.include_router(llm.router) +api_router.include_router(llm_chain.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) api_router.include_router(openai_conversation.router) diff --git a/backend/app/api/routes/llm_chain.py b/backend/app/api/routes/llm_chain.py new file mode 100644 index 000000000..dc12062d7 --- /dev/null +++ b/backend/app/api/routes/llm_chain.py @@ -0,0 +1,62 @@ 
+import logging + +from fastapi import APIRouter, Depends +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.models import LLMChainRequest, LLMChainResponse, Message +from app.services.llm.jobs import start_chain_job +from app.utils import APIResponse, validate_callback_url, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["LLM"]) +llm_callback_router = APIRouter() + + +@llm_callback_router.post( + "{$callback_url}", + name="llm_chain_callback", +) +def llm_callback_notification(body: APIResponse[LLMChainResponse]): + """ + Callback endpoint specification for LLM chain completion. + + The callback will receive: + - On success: APIResponse with success=True and data containing LLMChainResponse + - On failure: APIResponse with success=False and error message + - metadata field will always be included if provided in the request + """ + ... + + +@router.post( + "/llm/chain", + description=load_description("llm/llm_chain.md"), + response_model=APIResponse[Message], + callbacks=llm_callback_router.routes, + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def llm_chain( + _current_user: AuthContextDep, _session: SessionDep, request: LLMChainRequest +): + """ + Endpoint to initiate an LLM chain as a background job. + """ + project_id = _current_user.project_.id + organization_id = _current_user.organization_.id + + if request.callback_url: + validate_callback_url(str(request.callback_url)) + + start_chain_job( + db=_session, + request=request, + project_id=project_id, + organization_id=organization_id, + ) + + return APIResponse.success_response( + data=Message( + message="Your response is being generated and will be delivered via callback." 
+ ), + ) diff --git a/backend/app/crud/llm.py b/backend/app/crud/llm.py index 360bab4f2..e0ca2b171 100644 --- a/backend/app/crud/llm.py +++ b/backend/app/crud/llm.py @@ -48,6 +48,7 @@ def create_llm_call( *, request: LLMCallRequest, job_id: UUID, + chain_id: UUID | None = None, project_id: int, organization_id: int, resolved_config: ConfigBlob, @@ -128,6 +129,7 @@ def create_llm_call( job_id=job_id, project_id=project_id, organization_id=organization_id, + chain_id=chain_id, input=serialize_input(request.query.input), input_type=input_type, output_type=output_type, diff --git a/backend/app/crud/llm_chain.py b/backend/app/crud/llm_chain.py new file mode 100644 index 000000000..010d8abbd --- /dev/null +++ b/backend/app/crud/llm_chain.py @@ -0,0 +1,146 @@ +import logging +from typing import Any +from uuid import UUID + +from sqlmodel import Session + +from app.core.util import now +from app.models.llm.request import ChainStatus, LlmChain + +logger = logging.getLogger(__name__) + + +def create_llm_chain( + session: Session, + *, + job_id: UUID, + project_id: int, + organization_id: int, + total_blocks: int, + input: str, + configs: list[dict[str, Any]], +) -> LlmChain: + """Create a new LLM chain record. 
+ Args: + session: Database session + job_id: Reference to the parent job + project_id: Reference to the project + organization_id: Reference to the organization + total_blocks: Total number of blocks to execute + input: Serialized input string (via serialize_input) + configs: Ordered list of block configs as submitted + + Returns: + LlmChain: The created chain record + """ + db_llm_chain = LlmChain( + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + status=ChainStatus.PENDING, + total_blocks=total_blocks, + number_of_blocks_processed=0, + input=input, + configs=configs, + block_sequences=[], + ) + + session.add(db_llm_chain) + session.commit() + session.refresh(db_llm_chain) + + logger.info( + f"[create_llm_chain] Created LLM chain id={db_llm_chain.id}, " + f"job_id={job_id}, total_blocks={total_blocks}" + ) + + return db_llm_chain + + +def update_llm_chain_status( + session: Session, + *, + chain_id: UUID, + status: ChainStatus, + output: dict[str, Any] | None = None, + total_usage: dict[str, Any] | None = None, + error: str | None = None, +) -> LlmChain: + """Update chain record status and related fields. 
+ Args: + session: Database session + chain_id: The chain record ID + status: New chain status + output: Last block's output dict (only for COMPLETED) + total_usage: Aggregated token usage across all blocks (for COMPLETED/FAILED) + error: Error message (only for FAILED) + + Returns: + LlmChain: The updated chain record + """ + db_chain = session.get(LlmChain, chain_id) + if not db_chain: + raise ValueError(f"LLM chain not found with id={chain_id}") + + db_chain.status = status + db_chain.updated_at = now() + + if status == ChainStatus.FAILED: + db_chain.error = error + db_chain.total_usage = total_usage + + if status == ChainStatus.COMPLETED: + db_chain.output = output + db_chain.total_usage = total_usage + + session.add(db_chain) + session.commit() + session.refresh(db_chain) + + logger.info( + f"[update_llm_chain_status] Chain {chain_id} → {status.value} | " + f"has_output={output is not None}, " + f"blocks={db_chain.number_of_blocks_processed}/{db_chain.total_blocks}, " + f"error={error}" + ) + return db_chain + + +def update_llm_chain_block_completed( + session: Session, + *, + chain_id: UUID, + llm_call_id: UUID, +) -> LlmChain: + """Update chain progress after a block completes. 
+ Args: + session: Database session + chain_id: The chain record ID + llm_call_id: The llm_call record ID for the completed block + + Returns: + LlmChain: The updated chain record + """ + db_chain = session.get(LlmChain, chain_id) + if not db_chain: + raise ValueError(f"LLM chain not found with id={chain_id}") + + # Append to block_sequences + sequences = list(db_chain.block_sequences or []) + sequences.append(str(llm_call_id)) + db_chain.block_sequences = sequences + + # Increment progress + db_chain.number_of_blocks_processed = len(sequences) + db_chain.updated_at = now() + + session.add(db_chain) + session.commit() + session.refresh(db_chain) + + logger.info( + f"[update_llm_chain_block_completed] Chain {chain_id} | " + f"block={db_chain.number_of_blocks_processed}/{db_chain.total_blocks}, " + f"llm_call_id={llm_call_id}" + ) + return db_chain diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index e5f3e4270..6871c03c4 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -111,6 +111,9 @@ LLMCallRequest, LLMCallResponse, LlmCall, + LLMChainRequest, + LLMChainResponse, + LlmChain, ) from .message import Message diff --git a/backend/app/models/job.py b/backend/app/models/job.py index b6a1a5ae7..3b20249f5 100644 --- a/backend/app/models/job.py +++ b/backend/app/models/job.py @@ -17,6 +17,7 @@ class JobStatus(str, Enum): class JobType(str, Enum): RESPONSE = "RESPONSE" LLM_API = "LLM_API" + LLM_CHAIN = "LLM_CHAIN" class Job(SQLModel, table=True): diff --git a/backend/app/models/llm/__init__.py b/backend/app/models/llm/__init__.py index 67b288f39..1cb659f85 100644 --- a/backend/app/models/llm/__init__.py +++ b/backend/app/models/llm/__init__.py @@ -9,6 +9,13 @@ LlmCall, AudioContent, TextContent, + TextInput, + AudioInput, + PromptTemplate, + ChainBlock, + ChainStatus, + LLMChainRequest, + LlmChain, ImageContent, PDFContent, ImageInput, @@ -21,4 +28,6 @@ Usage, TextOutput, AudioOutput, + LLMChainResponse, + 
IntermediateChainResponse, ) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 22d75d18f..0a8c33818 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -1,10 +1,13 @@ -import sqlalchemy as sa -from typing import Annotated, Any, List, Literal, Union -from uuid import UUID, uuid4 -from pydantic import model_validator, HttpUrl from datetime import datetime +from enum import Enum +from typing import Annotated, Any, Literal, Union +from uuid import UUID, uuid4 + +import sqlalchemy as sa +from pydantic import HttpUrl, model_validator from sqlalchemy.dialects.postgresql import JSONB -from sqlmodel import Field, SQLModel, Index, text +from sqlmodel import Field, Index, SQLModel, text + from app.core.util import now @@ -248,11 +251,21 @@ class Validator(SQLModel): validator_config_id: UUID +class PromptTemplate(SQLModel): + template: str = Field(..., description="Template string with {{input}} placeholder") + + class ConfigBlob(SQLModel): """Raw JSON blob of config.""" completion: CompletionConfig = Field(..., description="Completion configuration") + # used for llm-chain to provide prompt interpolation + prompt_template: PromptTemplate | None = Field( + default=None, + description="Prompt template with {{input}} placeholder to wrap around the user input", + ) + input_guardrails: list[Validator] | None = Field( default=None, description="Guardrails applied to validate/sanitize the input before the LLM call", @@ -418,6 +431,16 @@ class LlmCall(SQLModel, table=True): }, ) + chain_id: UUID | None = Field( + default=None, + foreign_key="llm_chain.id", + nullable=True, + ondelete="SET NULL", + sa_column_kwargs={ + "comment": "Reference to the parent chain (NULL for standalone llm_call requests)" + }, + ) + # Request fields input: str = Field( ..., @@ -536,3 +559,201 @@ class LlmCall(SQLModel, table=True): nullable=True, sa_column_kwargs={"comment": "Timestamp when the record was soft-deleted"}, ) + + 
+class ChainBlock(SQLModel): + """A single block in an LLM chain execution.""" + + config: LLMCallConfig = Field( + ..., description="LLM call configuration (stored id+version OR ad-hoc blob)" + ) + + include_provider_raw_response: bool = Field( + default=False, + description="Whether to include the raw LLM provider response in the output for this block", + ) + + intermediate_callback: bool = Field( + default=False, + description="Whether to send intermediate callback after this block completes", + ) + + +class LLMChainRequest(SQLModel): + """ + API request for an LLM chain execution. + + Orchestrates multiple LLM calls sequentially where each block's output + becomes the next block's input. + """ + + query: QueryParams = Field( + ..., description="Initial query input for the first block in the chain" + ) + + blocks: list[ChainBlock] = Field( + ..., min_length=1, description="Ordered list of blocks to execute sequentially" + ) + + callback_url: HttpUrl | None = Field( + default=None, description="Webhook URL for async response delivery" + ) + + request_metadata: dict[str, Any] | None = Field( + default=None, + description=( + "Client-provided metadata passed through unchanged in the response. " + "Use this to correlate responses with requests or track request state. " + "The exact dictionary provided here will be returned in the response metadata field." + ), + ) + + +class ChainStatus(str, Enum): + """Status of an LLM chain execution.""" + + PENDING = "pending" + RUNNING = "running" + FAILED = "failed" + COMPLETED = "completed" + + +class LlmChain(SQLModel, table=True): + """ + Database model for tracking LLM chain execution + + it manages and orchestrates sequential llm_call executions. 
+ """ + + __tablename__ = "llm_chain" + __table_args__ = ( + Index( + "idx_llm_chain_job_id", + "job_id", + ), + ) + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the LLM chain record"}, + ) + + job_id: UUID = Field( + foreign_key="job.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the parent job (status tracked in job table)" + }, + ) + + project_id: int = Field( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the project this LLM call belongs to" + }, + ) + + organization_id: int = Field( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the organization this LLM call belongs to" + }, + ) + + status: ChainStatus = Field( + default=ChainStatus.PENDING, + sa_column_kwargs={ + "comment": "Chain execution status (pending, running, failed, completed)" + }, + ) + + error: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Error message if the chain execution failed"}, + ) + + block_sequences: list[str] | None = Field( + default_factory=list, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Ordered list of llm_call UUIDs as blocks complete", + ), + ) + + total_blocks: int = Field( + ..., sa_column_kwargs={"comment": "Total number of blocks to execute"} + ) + + number_of_blocks_processed: int = Field( + default=0, + sa_column_kwargs={ + "comment": "Number of blocks processed so far (used for tracking progress)" + }, + ) + + # Request fields + input: str = Field( + ..., + sa_column_kwargs={ + "comment": "First block user's input - text string, binary data, or file path for multimodal" + }, + ) + + output: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Last block's final output (set on chain completion)", + ), + ) + 
+ configs: list[dict[str, Any]] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Ordered list of block configs as submitted in the request", + ), + ) + + total_usage: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Aggregated token usage: {input_tokens, output_tokens, total_tokens}", + ), + ) + + metadata_: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + "metadata", + JSONB, + nullable=True, + comment="Future-proof extensibility catch-all", + ), + ) + + inserted_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the chain record was created"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={ + "comment": "Timestamp when the chain record was last updated" + }, + ) diff --git a/backend/app/models/llm/response.py b/backend/app/models/llm/response.py index 780ba2338..9932ec91f 100644 --- a/backend/app/models/llm/response.py +++ b/backend/app/models/llm/response.py @@ -61,3 +61,42 @@ class LLMCallResponse(SQLModel): default=None, description="Unmodified raw response from the LLM provider.", ) + + +class LLMChainResponse(SQLModel): + """Response schema for an LLM chain execution.""" + + response: LLMResponse = Field( + ..., description="LLM response from the final step of the chain execution." + ) + usage: Usage = Field( + ..., + description="Aggregate token usage and cost for the entire chain execution.", + ) + provider_raw_response: dict[str, object] | None = Field( + default=None, + description="Raw provider response from the last block (if requested)", + ) + + +class IntermediateChainResponse(SQLModel): + """ + Intermediate callback response from the intermediate blocks + from the llm chain execution. 
(if configured) + + Flattened structure matching LLMCallResponse keys for consistency + """ + + type: Literal["intermediate"] = "intermediate" + block_index: int = Field(..., description="Current block position") + total_blocks: int = Field(..., description="Total number of blocks in the chain") + response: LLMResponse = Field( + ..., description="LLM Response from the current block" + ) + usage: Usage = Field( + ..., description="Token usage and cost information from the current block" + ) + provider_raw_response: dict[str, object] | None = Field( + default=None, + description="Unmodified raw response from the LLM provider from the current block", + ) diff --git a/backend/app/services/llm/chain/__init__.py b/backend/app/services/llm/chain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py new file mode 100644 index 000000000..ad0503675 --- /dev/null +++ b/backend/app/services/llm/chain/chain.py @@ -0,0 +1,137 @@ +import logging +from dataclasses import dataclass, field +from typing import Any, Callable +from uuid import UUID + +from app.models.llm.request import ( + LLMCallConfig, + QueryParams, + TextInput, + TextContent, + AudioInput, +) +from app.models.llm.response import ( + TextOutput, + AudioOutput, + Usage, +) +from app.services.llm.chain.types import BlockResult +from app.services.llm.jobs import execute_llm_call + + +logger = logging.getLogger(__name__) + + +@dataclass +class ChainContext: + """Shared state for chain execution.""" + + job_id: UUID + chain_id: UUID + project_id: int + organization_id: int + callback_url: str | None + total_blocks: int + + langfuse_credentials: dict[str, Any] | None = None + request_metadata: dict | None = None + intermediate_callback_flags: list[bool] = field(default_factory=list) + aggregated_usage: Usage = field( + default_factory=lambda: Usage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + ) + ) + + +def 
result_to_query(result: BlockResult) -> QueryParams: + """Convert a block's output into the next block's QueryParams. + + Text output → TextInput query + Audio output → AudioInput query + """ + output = result.response.response.output + + if isinstance(output, TextOutput): + return QueryParams( + input=TextInput(content=TextContent(value=output.content.value)) + ) + elif isinstance(output, AudioOutput): + return QueryParams(input=AudioInput(content=output.content)) + else: + raise ValueError(f"Cannot chain output type: {output.type}") + + +class ChainBlock: + """A single block in the chain. Only responsible for executing itself.""" + + def __init__( + self, + *, + config: LLMCallConfig, + index: int, + context: ChainContext, + include_provider_raw_response: bool = False, + ): + self._config = config + self._index = index + self._context = context + self._include_provider_raw_response = include_provider_raw_response + + def execute(self, query: QueryParams) -> BlockResult: + """Execute this block and return the result.""" + logger.info( + f"[ChainBlock.execute] Executing block {self._index} | " + f"job_id={self._context.job_id}" + ) + + return execute_llm_call( + config=self._config, + query=query, + job_id=self._context.job_id, + project_id=self._context.project_id, + organization_id=self._context.organization_id, + request_metadata=self._context.request_metadata, + langfuse_credentials=self._context.langfuse_credentials, + include_provider_raw_response=self._include_provider_raw_response, + chain_id=self._context.chain_id, + ) + + +class LLMChain: + """Orchestrates sequential execution of ChainBlocks.""" + + def __init__(self, blocks: list[ChainBlock], context: ChainContext): + self._blocks = blocks + self._context = context + + def execute( + self, + query: QueryParams, + on_block_completed: Callable[[int, BlockResult], None] | None = None, + ) -> BlockResult: + """Execute blocks sequentially, passing output of each to the next.""" + if not self._blocks: + return 
BlockResult(error="Chain has no blocks") + + current_query = query + result: BlockResult | None = None + + for block in self._blocks: + result = block.execute(current_query) + + if on_block_completed: + on_block_completed(block._index, result) + + if not result.success: + logger.error( + f"[LLMChain.execute] Block {block._index} failed: {result.error} | " + f"job_id={self._context.job_id}" + ) + return result + + if block is not self._blocks[-1]: + current_query = result_to_query(result) + + return result diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py new file mode 100644 index 000000000..27ab8de86 --- /dev/null +++ b/backend/app/services/llm/chain/executor.py @@ -0,0 +1,186 @@ +import logging + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.jobs import JobCrud +from app.crud.llm_chain import update_llm_chain_block_completed, update_llm_chain_status +from app.models import JobStatus, JobUpdate +from app.models.llm.request import ( + ChainStatus, + LLMChainRequest, +) +from app.models.llm.response import IntermediateChainResponse, LLMChainResponse +from app.services.llm.chain.chain import ChainContext, LLMChain +from app.services.llm.chain.types import BlockResult +from app.utils import APIResponse, send_callback + +logger = logging.getLogger(__name__) + + +class ChainExecutor: + """Manage the lifecycle of an LLM chain execution.""" + + def __init__( + self, + *, + chain: LLMChain, + context: ChainContext, + request: LLMChainRequest, + ): + self._chain = chain + self._context = context + self._request = request + + def run(self) -> dict: + """Execute the full chain lifecycle. 
Returns serialized APIResponse.""" + try: + self._setup() + + result = self._chain.execute( + self._request.query, + on_block_completed=self._on_block_completed, + ) + + return self._teardown(result) + + except Exception as e: + return self._handle_unexpected_error(e) + + def _setup(self) -> None: + with Session(engine) as session: + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.PROCESSING), + ) + + update_llm_chain_status( + session=session, + chain_id=self._context.chain_id, + status=ChainStatus.RUNNING, + ) + + def _teardown(self, result: BlockResult) -> dict: + """Finalize chain record, send callback (reporting chain-wide aggregated usage), and update job status.""" + + with Session(engine) as session: + if result.success: + final = LLMChainResponse( + response=result.response.response, + usage=self._context.aggregated_usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_response = APIResponse.success_response( + data=final, metadata=self._request.request_metadata + ) + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.SUCCESS), + ) + update_llm_chain_status( + session=session, + chain_id=self._context.chain_id, + status=ChainStatus.COMPLETED, + output=result.response.response.output.model_dump(), + total_usage=self._context.aggregated_usage.model_dump(), + ) + return callback_response.model_dump() + else: + return self._handle_error(result.error) + + def _handle_error(self, error: str) -> dict: + callback_response = APIResponse.failure_response( + error=error or "Unknown error occurred", + metadata=self._request.request_metadata, + ) + logger.error( + f"[_handle_error] Chain execution failed | " + f"chain_id={self._context.chain_id}, job_id={self._context.job_id}, error={error}" + ) + + with Session(engine) as session: + if self._request.callback_url: + 
send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + + update_llm_chain_status( + session, + chain_id=self._context.chain_id, + status=ChainStatus.FAILED, + output=None, + total_usage=self._context.aggregated_usage.model_dump(), + error=error, + ) + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.FAILED, error_message=error), + ) + return callback_response.model_dump() + + def _on_block_completed(self, block_index: int, result: BlockResult) -> None: + """Handle side effects after each block completes.""" + if result.usage: + self._context.aggregated_usage.input_tokens += result.usage.input_tokens + self._context.aggregated_usage.output_tokens += result.usage.output_tokens + self._context.aggregated_usage.total_tokens += result.usage.total_tokens + + if result.success and result.llm_call_id: + with Session(engine) as session: + update_llm_chain_block_completed( + session, + chain_id=self._context.chain_id, + llm_call_id=result.llm_call_id, + ) + + if ( + block_index < len(self._context.intermediate_callback_flags) + and self._context.intermediate_callback_flags[block_index] + and self._request.callback_url + and block_index < self._context.total_blocks - 1 + ): + self._send_intermediate_callback(block_index, result) + + def _send_intermediate_callback( + self, block_index: int, result: BlockResult + ) -> None: + """Send intermediate callback for a completed block.""" + try: + intermediate = IntermediateChainResponse( + block_index=block_index + 1, + total_blocks=self._context.total_blocks, + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_data = APIResponse.success_response( + data=intermediate, + metadata=self._context.request_metadata, + ) + send_callback( + callback_url=str(self._request.callback_url), + data=callback_data.model_dump(), + ) + logger.info( + 
f"[_send_intermediate_callback] Sent intermediate callback | " + f"block={block_index + 1}/{self._context.total_blocks}, job_id={self._context.job_id}" + ) + except Exception as e: + logger.warning( + f"[_send_intermediate_callback] Failed to send intermediate callback: {e} | " + f"block={block_index + 1}/{self._context.total_blocks}, job_id={self._context.job_id}" + ) + + def _handle_unexpected_error(self, e: Exception) -> dict: + logger.error( + f"[ChainExecutor.run] Unexpected error: {e} | " + f"job_id={self._context.job_id}", + exc_info=True, + ) + return self._handle_error("Unexpected error occurred") diff --git a/backend/app/services/llm/chain/types.py b/backend/app/services/llm/chain/types.py new file mode 100644 index 000000000..7fa0f39d8 --- /dev/null +++ b/backend/app/services/llm/chain/types.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from uuid import UUID + +from app.models.llm.response import LLMCallResponse, Usage + + +@dataclass +class BlockResult: + """Result of a single block/LLM call execution.""" + + response: LLMCallResponse | None = None + llm_call_id: UUID | None = None + usage: Usage | None = None + error: str | None = None + metadata: dict | None = None + + @property + def success(self) -> bool: + return self.error is None and self.response is not None diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 5cdc0d32b..2a5f7dee2 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from typing import Any from uuid import UUID + from asgi_correlation_id import correlation_id from fastapi import HTTPException from sqlmodel import Session @@ -12,25 +13,29 @@ from app.crud.config import ConfigVersionCrud from app.crud.credentials import get_provider_credential from app.crud.jobs import JobCrud -from app.crud.llm import create_llm_call, update_llm_call_response -from app.models import JobStatus, JobType, JobUpdate, 
LLMCallRequest, Job +from app.crud.llm import create_llm_call, serialize_input, update_llm_call_response +from app.crud.llm_chain import create_llm_chain, update_llm_chain_status +from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest from app.models.llm.request import ( - ConfigBlob, - LLMCallConfig, - KaapiCompletionConfig, - TextInput, AudioInput, + ChainStatus, + ConfigBlob, ImageInput, + KaapiCompletionConfig, + LLMCallConfig, PDFInput, + QueryParams, + TextInput, ) from app.models.llm.response import TextOutput +from app.services.llm.chain.types import BlockResult from app.services.llm.guardrails import ( list_validators_config, run_guardrails_validation, ) -from app.services.llm.providers.registry import get_llm_provider from app.services.llm.mappers import transform_kaapi_config_to_native -from app.utils import APIResponse, send_callback, resolve_input, cleanup_temp_file +from app.services.llm.providers.registry import get_llm_provider +from app.utils import APIResponse, cleanup_temp_file, resolve_input, send_callback logger = logging.getLogger(__name__) @@ -77,6 +82,49 @@ def start_job( return job.id +def start_chain_job( + db: Session, request: LLMChainRequest, project_id: int, organization_id: int +) -> UUID: + """Create an LLM Chain job and schedule Celery task.""" + trace_id = correlation_id.get() or "N/A" + job_crud = JobCrud(session=db) + job = job_crud.create(job_type=JobType.LLM_CHAIN, trace_id=trace_id) + + # Explicitly flush to ensure job is persisted before Celery task starts + db.flush() + db.commit() + + logger.info( + f"[start_chain_job] Created job | job_id={job.id}, status={job.status}, project_id={project_id}" + ) + + try: + task_id = start_high_priority_job( + function_path="app.services.llm.jobs.execute_chain_job", + project_id=project_id, + job_id=str(job.id), + trace_id=trace_id, + request_data=request.model_dump(mode="json"), + organization_id=organization_id, + ) + except Exception as e: + logger.error( 
+ f"[start_chain_job] Error starting Celery task: {str(e)} | job_id={job.id}, project_id={project_id}", + exc_info=True, + ) + job_update = JobUpdate(status=JobStatus.FAILED, error_message=str(e)) + job_crud.update(job_id=job.id, job_update=job_update) + raise HTTPException( + status_code=500, + detail="Internal server error while executing LLM chain job", + ) + + logger.info( + f"[start_chain_job] Job scheduled for LLM chain job | job_id={job.id}, project_id={project_id}, task_id={task_id}" + ) + return job.id + + def handle_job_error( job_id: UUID, callback_url: str | None, @@ -125,55 +173,6 @@ def resolved_input_context( cleanup_temp_file(resolved_input) -def validate_text_with_guardrails( - text: str, - guardrails: list[dict[str, Any]], - job_id: UUID, - project_id: int, - organization_id: int, - guardrail_type: str, # "input" or "output" -) -> tuple[str | None, str | None]: - """Validate text against guardrails. - - Returns: - (validated_text, error_message) - - If successful: (modified_text, None) - - If failed: (None, error_message) - - If bypassed: (original_text, None) - """ - safe_result = run_guardrails_validation( - text, - guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) - - logger.info( - f"[validate_text_with_guardrails] {guardrail_type.capitalize()} guardrail validation | " - f"success={safe_result['success']}, job_id={job_id}" - ) - - if safe_result.get("bypassed"): - logger.info( - f"[validate_text_with_guardrails] Guardrails bypassed (service unavailable) | " - f"job_id={job_id}" - ) - return text, None - - if safe_result["success"]: - validated_text = safe_result["data"]["safe_text"] - - # Special case for output guardrails: check if rephrase is needed - if guardrail_type == "output" and safe_result["data"].get("rephrase_needed"): - return None, "Output requires rephrasing" - - return validated_text, None - - return None, safe_result["error"] - - def resolve_config_blob( config_crud: ConfigVersionCrud, 
config: LLMCallConfig ) -> tuple[ConfigBlob | None, str | None]: @@ -209,182 +208,228 @@ def resolve_config_blob( return None, "Unexpected error occurred while parsing stored configuration" -def execute_job( - request_data: dict, +def apply_input_guardrails( + *, + config_blob: ConfigBlob | None, + query: QueryParams, + job_id: UUID, project_id: int, organization_id: int, - job_id: str, - task_id: str, - task_instance, -) -> dict: - """Celery task to process an LLM request asynchronously. +) -> tuple[QueryParams, str | None]: + """Apply input guardrails from a config_blob. Shared with llm-call and llm-chain.""" + if not config_blob or not config_blob.input_guardrails: + return query, None - Returns: - dict: Serialized APIResponse[LLMCallResponse] on success, APIResponse[None] on failure + if not isinstance(query.input, TextInput): + logger.info( + f"[apply_input_guardrails] Skipping for non-text input. " + f"job_id={job_id}, " + f"input_type={getattr(query.input, 'type', type(query.input).__name__)}" + ) + return query, None + + input_guardrails, _ = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=config_blob.input_guardrails, + output_validator_configs=None, + ) + + if not input_guardrails: + return query, None + + safe = run_guardrails_validation( + query.input.content.value, + input_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) + + logger.info( + f"[apply_input_guardrails] Validation result | success={safe['success']}, job_id={job_id}" + ) + + if safe.get("bypassed"): + logger.info( + f"[apply_input_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" + ) + return query, None + + if safe["success"]: + query.input.content.value = safe["data"]["safe_text"] + return query, None + + return query, safe["error"] + + +def apply_output_guardrails( + *, + config_blob: ConfigBlob | None, + result: BlockResult, + job_id: UUID, + project_id: int, + 
organization_id: int, +) -> tuple[BlockResult, str | None]: + """Apply output guardrails from a config_blob. Shared by /llm/call and /llm/chain. + + Returns (modified_result, None) on success, or (result, error_string) on failure. """ + if not config_blob or not config_blob.output_guardrails: + return result, None - request = LLMCallRequest(**request_data) - job_uuid = UUID(job_id) # Renamed to avoid shadowing parameter - callback_url_str = str(request.callback_url) if request.callback_url else None + if not isinstance(result.response.response.output, TextOutput): + logger.info( + f"[apply_output_guardrails] Skipping for non-text output. " + f"job_id={job_id}, " + f"output_type={getattr(result.response.response.output, 'type', type(result.response.response.output).__name__)}" + ) + return result, None - config = request.config - callback_response = None - config_blob: ConfigBlob | None = None - input_guardrails: list[dict] = [] - output_guardrails: list[dict] = [] - llm_call_id: UUID | None = None # Track the LLM call record + _, output_guardrails = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=None, + output_validator_configs=config_blob.output_guardrails, + ) + + if not output_guardrails: + return result, None + + output_text = result.response.response.output.content.value + safe = run_guardrails_validation( + output_text, + output_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) logger.info( - f"[execute_job] Starting LLM job execution | job_id={job_uuid}, task_id={task_id}" + f"[apply_output_guardrails] Validation result | success={safe['success']}, job_id={job_id}" ) + if safe.get("bypassed"): + logger.info( + f"[apply_output_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" + ) + return result, None + + if safe["success"]: + result.response.response.output.content.value = safe["data"]["safe_text"] + return result, None + + return result, 
safe["error"] + + +def execute_llm_call( + *, + config: LLMCallConfig, + query: QueryParams, + job_id: UUID, + project_id: int, + organization_id: int, + request_metadata: dict | None, + langfuse_credentials: dict | None, + include_provider_raw_response: bool = False, + chain_id: UUID | None = None, +) -> BlockResult: + """Execute a single LLM call. Shared by /llm/call and /llm/chain. + + Returns BlockResult with response + usage on success, or error on failure. + """ + + config_blob: ConfigBlob | None = None + llm_call_id: UUID | None = None + try: with Session(engine) as session: - # Update job status to PROCESSING - job_crud = JobCrud(session=session) - job_crud.update( - job_id=job_uuid, job_update=JobUpdate(status=JobStatus.PROCESSING) - ) - - # if stored config, fetch blob from DB if config.is_stored_config: config_crud = ConfigVersionCrud( session=session, project_id=project_id, config_id=config.id ) - - # blob is dynamic, need to resolve to ConfigBlob format config_blob, error = resolve_config_blob(config_crud, config) - if error: - callback_response = APIResponse.failure_response( - error=error, - metadata=request.request_metadata, - ) - return handle_job_error( - job_uuid, callback_url_str, callback_response - ) - + return BlockResult(error=error) else: config_blob = config.blob - if config_blob is not None: - if config_blob.input_guardrails or config_blob.output_guardrails: - input_guardrails, output_guardrails = list_validators_config( - organization_id=organization_id, - project_id=project_id, - input_validator_configs=config_blob.input_guardrails, - output_validator_configs=config_blob.output_guardrails, - ) - - if input_guardrails: - if not isinstance(request.query.input, TextInput): - logger.info( - "[execute_job] Skipping input guardrails for non-text input. 
" - f"job_id={job_uuid}, input_type={getattr(request.query.input, 'type', type(request.query.input).__name__)}" - ) - else: - validated_text, error = validate_text_with_guardrails( - request.query.input.content.value, - input_guardrails, - job_uuid, - project_id, - organization_id, - guardrail_type="input", - ) - - if error: - callback_response = APIResponse.failure_response( - error=error, - metadata=request.request_metadata, - ) - return handle_job_error( - job_uuid, callback_url_str, callback_response - ) + if config_blob.prompt_template and isinstance(query.input, TextInput): + template = config_blob.prompt_template.template + interpolated = template.replace("{{input}}", query.input.content.value) + query.input.content.value = interpolated - # Update input with validated text - request.query.input.content.value = validated_text - try: - # Transform Kaapi config to native config if needed (before getting provider) - completion_config = config_blob.completion - - original_provider = ( - config_blob.completion.provider - ) # openai, google or prefixed - - if isinstance(completion_config, KaapiCompletionConfig): - completion_config, warnings = transform_kaapi_config_to_native( - completion_config - ) + query, input_error = apply_input_guardrails( + config_blob=config_blob, + query=query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if input_error: + return BlockResult(error=input_error) - if request.request_metadata is None: - request.request_metadata = {} - request.request_metadata.setdefault("warnings", []).extend(warnings) + completion_config = config_blob.completion + original_provider = completion_config.provider - except Exception as e: - callback_response = APIResponse.failure_response( - error=f"Error processing configuration: {str(e)}", - metadata=request.request_metadata, + if isinstance(completion_config, KaapiCompletionConfig): + completion_config, warnings = transform_kaapi_config_to_native( + completion_config ) - 
return handle_job_error(job_uuid, callback_url_str, callback_response) + if request_metadata is None: + request_metadata = {} + request_metadata.setdefault("warnings", []).extend(warnings) + + resolved_config_blob = ConfigBlob( + completion=completion_config, + prompt_template=config_blob.prompt_template, + input_guardrails=config_blob.input_guardrails, + output_guardrails=config_blob.output_guardrails, + ) - # Create LLM call record before execution try: - # Rebuild ConfigBlob with transformed native config - resolved_config_blob = ConfigBlob( - completion=completion_config, - input_guardrails=config_blob.input_guardrails, - output_guardrails=config_blob.output_guardrails, + llm_call_request = LLMCallRequest( + query=query, + config=config, + request_metadata=request_metadata, ) - llm_call = create_llm_call( session, - request=request, - job_id=job_uuid, + request=llm_call_request, + job_id=job_id, project_id=project_id, organization_id=organization_id, resolved_config=resolved_config_blob, original_provider=original_provider, + chain_id=chain_id, ) llm_call_id = llm_call.id logger.info( - f"[execute_job] Created LLM call record | llm_call_id={llm_call_id}, job_id={job_uuid}" + f"[execute_llm_call] Created LLM call record | " + f"llm_call_id={llm_call_id}, job_id={job_id}" ) except Exception as e: logger.error( - f"[execute_job] Failed to create LLM call record: {str(e)} | job_id={job_uuid}", + f"[execute_llm_call] Failed to create LLM call record: {e} | job_id={job_id}", exc_info=True, ) - callback_response = APIResponse.failure_response( - error=f"Failed to create LLM call record: {str(e)}", - metadata=request.request_metadata, - ) - return handle_job_error(job_uuid, callback_url_str, callback_response) + return BlockResult(error=f"Failed to create LLM call record: {str(e)}") try: provider_instance = get_llm_provider( session=session, - provider_type=completion_config.provider, # Now always native provider type i.e openai-native, google-native regardless + 
provider_type=completion_config.provider, project_id=project_id, organization_id=organization_id, ) except ValueError as ve: - callback_response = APIResponse.failure_response( - error=str(ve), - metadata=request.request_metadata, - ) - return handle_job_error(job_uuid, callback_url_str, callback_response) + return BlockResult(error=str(ve), llm_call_id=llm_call_id) - langfuse_credentials = get_provider_credential( - session=session, - org_id=organization_id, - project_id=project_id, - provider="langfuse", - ) - - # Extract conversation_id for langfuse session grouping conversation_id = None - if request.query.conversation and request.query.conversation.id: - conversation_id = request.query.conversation.id + if query.conversation and query.conversation.id: + conversation_id = query.conversation.id # Apply Langfuse observability decorator to provider execute method decorated_execute = observe_llm_execution( @@ -394,63 +439,18 @@ def execute_job( # Resolve input and execute LLM (context manager handles cleanup) try: - with resolved_input_context(request.query.input) as resolved_input: + with resolved_input_context(query.input) as resolved_input: response, error = decorated_execute( completion_config=completion_config, - query=request.query, + query=query, resolved_input=resolved_input, - include_provider_raw_response=request.include_provider_raw_response, + include_provider_raw_response=include_provider_raw_response, ) except ValueError as ve: - # Handle input resolution errors from context manager - callback_response = APIResponse.failure_response( - error=str(ve), - metadata=request.request_metadata, - ) - return handle_job_error(job_uuid, callback_url_str, callback_response) + return BlockResult(error=str(ve), llm_call_id=llm_call_id) if response: - if output_guardrails: - if not isinstance(response.response.output, TextOutput): - logger.info( - "[execute_job] Skipping output guardrails for non-text output. 
" - f"job_id={job_uuid}, output_type={getattr(response.response.output, 'type', type(response.response.output).__name__)}" - ) - else: - output_text = response.response.output.content.value - validated_text, error = validate_text_with_guardrails( - output_text, - output_guardrails, - job_uuid, - project_id, - organization_id, - guardrail_type="output", - ) - - if error: - callback_response = APIResponse.failure_response( - error=error, - metadata=request.request_metadata, - ) - return handle_job_error( - job_uuid, callback_url_str, callback_response - ) - - # Update output with validated text - response.response.output.content.value = validated_text - callback_response = APIResponse.success_response( - data=response, metadata=request.request_metadata - ) - if callback_url_str: - send_callback( - callback_url=callback_url_str, - data=callback_response.model_dump(), - ) - with Session(engine) as session: - job_crud = JobCrud(session=session) - - # Update LLM call record with response data if llm_call_id: try: update_llm_call_response( @@ -461,35 +461,122 @@ def execute_job( usage=response.usage.model_dump(), conversation_id=response.response.conversation_id, ) - logger.info( - f"[execute_job] Updated LLM call record | llm_call_id={llm_call_id}" - ) except Exception as e: logger.error( - f"[execute_job] Failed to update LLM call record: {str(e)} | llm_call_id={llm_call_id}", + f"[execute_llm_call] Failed to update LLM call record: {e} | " + f"llm_call_id={llm_call_id}", exc_info=True, ) - # Don't fail the job if updating the record fails + result = BlockResult( + response=response, + llm_call_id=llm_call_id, + usage=response.usage, + metadata=request_metadata, + ) - job_crud.update( + result, output_error = apply_output_guardrails( + config_blob=config_blob, + result=result, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if output_error: + return BlockResult(error=output_error, llm_call_id=llm_call_id) + + return result + + return 
BlockResult( + error=error or "Unknown error occurred", + llm_call_id=llm_call_id, + ) + + except Exception as e: + logger.error( + f"[execute_llm_call] Unexpected error: {e} | job_id={job_id}", + exc_info=True, + ) + return BlockResult( + error="Unexpected error occurred", + llm_call_id=llm_call_id, + ) + + +def execute_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task to process an LLM request asynchronously. + + Returns: + dict: Serialized APIResponse[LLMCallResponse] on success, APIResponse[None] on failure + """ + request = LLMCallRequest(**request_data) + job_uuid = UUID(job_id) # Renamed to avoid shadowing parameter + callback_url_str = str(request.callback_url) if request.callback_url else None + + logger.info( + f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}" + ) + + try: + with Session(engine) as session: + job_crud = JobCrud(session=session) + job_crud.update( + job_id=job_uuid, job_update=JobUpdate(status=JobStatus.PROCESSING) + ) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + result = execute_llm_call( + config=request.config, + query=request.query, + job_id=job_uuid, + project_id=project_id, + organization_id=organization_id, + request_metadata=request.request_metadata, + langfuse_credentials=langfuse_credentials, + include_provider_raw_response=request.include_provider_raw_response, + ) + + if result.success: + callback_response = APIResponse.success_response( + data=result.response, metadata=result.metadata + ) + if callback_url_str: + send_callback( + callback_url=callback_url_str, + data=callback_response.model_dump(), + ) + + with Session(engine) as session: + JobCrud(session=session).update( job_id=job_uuid, job_update=JobUpdate(status=JobStatus.SUCCESS) ) logger.info( - f"[execute_job] Successfully completed LLM 
job | job_id={job_uuid}, " - f"provider_response_id={response.response.provider_response_id}, tokens={response.usage.total_tokens}" + f"[execute_job] Successfully completed LLM job | job_id={job_id}, " + f"tokens={result.usage.total_tokens}" ) return callback_response.model_dump() callback_response = APIResponse.failure_response( - error=error or "Unknown error occurred", + error=result.error or "Unknown error occurred", metadata=request.request_metadata, ) return handle_job_error(job_uuid, callback_url_str, callback_response) except Exception as e: - error_type = type(e).__name__ callback_response = APIResponse.failure_response( - error=f"Unexpected error during LLM execution: {error_type}", + error="Unexpected error occurred", metadata=request.request_metadata, ) logger.error( @@ -497,3 +584,113 @@ def execute_job( exc_info=True, ) return handle_job_error(job_uuid, callback_url_str, callback_response) + + +def execute_chain_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task to process an LLM Chain request asynchronously. 
+ + Returns: + dict: Serialized APIResponse[LLMChainResponse] on success, APIResponse[None] on failure + """ + # imports to avoid circular dependency: + from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain + from app.services.llm.chain.executor import ChainExecutor + + request = LLMChainRequest(**request_data) + job_uuid = UUID(job_id) + callback_url_str = str(request.callback_url) if request.callback_url else None + chain_uuid = None + + logger.info( + f"[execute_chain_job] Starting chain execution | " + f"job_id={job_uuid}, total_blocks={len(request.blocks)}" + ) + + try: + with Session(engine) as session: + chain_record = create_llm_chain( + session, + job_id=job_uuid, + project_id=project_id, + organization_id=organization_id, + total_blocks=len(request.blocks), + input=serialize_input(request.query.input), + configs=[block.model_dump(mode="json") for block in request.blocks], + ) + chain_uuid = chain_record.id + + logger.info( + f"[execute_chain_job] Created chain record | " + f"chain_id={chain_uuid}, job_id={job_uuid}" + ) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + context = ChainContext( + job_id=job_uuid, + chain_id=chain_uuid, + project_id=project_id, + organization_id=organization_id, + langfuse_credentials=langfuse_credentials, + request_metadata=request.request_metadata, + total_blocks=len(request.blocks), + callback_url=str(request.callback_url) if request.callback_url else None, + intermediate_callback_flags=[ + block.intermediate_callback for block in request.blocks + ], + ) + + blocks = [ + ChainBlock( + config=block.config, + index=i, + context=context, + include_provider_raw_response=block.include_provider_raw_response, + ) + for i, block in enumerate(request.blocks) + ] + + chain = LLMChain(blocks, context) + + executor = ChainExecutor(chain=chain, context=context, request=request) + return executor.run() + + except 
Exception as e: + logger.error( + f"[execute_chain_job] Failed: {e} | job_id={job_uuid}", + exc_info=True, + ) + + if chain_uuid: + try: + with Session(engine) as session: + update_llm_chain_status( + session, + chain_id=chain_uuid, + status=ChainStatus.FAILED, + error=str(e), + ) + except Exception: + logger.error( + f"[execute_chain_job] Failed to update chain status: {e} | " + f"chain_id={chain_uuid}", + exc_info=True, + ) + + callback_response = APIResponse.failure_response( + error="Unexpected error occurred", + metadata=request.request_metadata, + ) + return handle_job_error(job_uuid, callback_url_str, callback_response) diff --git a/backend/app/tests/crud/test_llm_chain.py b/backend/app/tests/crud/test_llm_chain.py new file mode 100644 index 000000000..dfeceeee4 --- /dev/null +++ b/backend/app/tests/crud/test_llm_chain.py @@ -0,0 +1,150 @@ +import pytest +from uuid import uuid4 + +from sqlmodel import Session + +from app.crud import JobCrud +from app.crud.llm_chain import ( + create_llm_chain, + update_llm_chain_status, + update_llm_chain_block_completed, +) +from app.models import JobType +from app.models.llm.request import ChainStatus +from app.tests.utils.utils import get_project + + +class TestCreateLlmChain: + def test_creates_chain_record(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + total_blocks=3, + input="Test input", + configs=[{"completion": {"provider": "openai-native"}}], + ) + + assert chain.id is not None + assert chain.job_id == job.id + assert chain.project_id == project.id + assert chain.status == ChainStatus.PENDING + assert chain.total_blocks == 3 + assert chain.number_of_blocks_processed == 0 + assert chain.input == "Test input" + assert chain.block_sequences == [] + + +class TestUpdateLlmChainStatus: + 
@pytest.fixture + def chain(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + total_blocks=2, + input="hello", + configs=[], + ) + return chain + + def test_update_to_running(self, db: Session, chain): + updated = update_llm_chain_status( + db, chain_id=chain.id, status=ChainStatus.RUNNING + ) + + assert updated.status == ChainStatus.RUNNING + + def test_update_to_completed(self, db: Session, chain): + output = {"type": "text", "content": {"value": "result"}} + usage = {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30} + + updated = update_llm_chain_status( + db, + chain_id=chain.id, + status=ChainStatus.COMPLETED, + output=output, + total_usage=usage, + ) + + assert updated.status == ChainStatus.COMPLETED + assert updated.output == output + assert updated.total_usage == usage + + def test_update_to_failed(self, db: Session, chain): + usage = {"input_tokens": 5, "output_tokens": 0, "total_tokens": 5} + + updated = update_llm_chain_status( + db, + chain_id=chain.id, + status=ChainStatus.FAILED, + error="Provider timeout", + total_usage=usage, + ) + + assert updated.status == ChainStatus.FAILED + assert updated.error == "Provider timeout" + assert updated.total_usage == usage + + def test_raises_for_missing_chain(self, db: Session): + with pytest.raises(ValueError, match="LLM chain not found"): + update_llm_chain_status(db, chain_id=uuid4(), status=ChainStatus.RUNNING) + + +class TestUpdateLlmChainBlockCompleted: + @pytest.fixture + def chain(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + total_blocks=3, + 
input="hello", + configs=[], + ) + return chain + + def test_appends_llm_call_id(self, db: Session, chain): + call_id = uuid4() + + updated = update_llm_chain_block_completed( + db, chain_id=chain.id, llm_call_id=call_id + ) + + assert str(call_id) in updated.block_sequences + assert updated.number_of_blocks_processed == 1 + + def test_appends_multiple_blocks(self, db: Session, chain): + call_id_1 = uuid4() + call_id_2 = uuid4() + + update_llm_chain_block_completed(db, chain_id=chain.id, llm_call_id=call_id_1) + updated = update_llm_chain_block_completed( + db, chain_id=chain.id, llm_call_id=call_id_2 + ) + + assert len(updated.block_sequences) == 2 + assert updated.number_of_blocks_processed == 2 + + def test_raises_for_missing_chain(self, db: Session): + with pytest.raises(ValueError, match="LLM chain not found"): + update_llm_chain_block_completed(db, chain_id=uuid4(), llm_call_id=uuid4()) diff --git a/backend/app/tests/services/llm/test_chain.py b/backend/app/tests/services/llm/test_chain.py new file mode 100644 index 000000000..5b5cfed3f --- /dev/null +++ b/backend/app/tests/services/llm/test_chain.py @@ -0,0 +1,219 @@ +from unittest.mock import patch, MagicMock +from uuid import uuid4 + +import pytest + +from app.models.llm.request import ( + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + QueryParams, + TextInput, + TextContent, + AudioInput, +) +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + Usage, + TextOutput, + TextContent as ResponseTextContent, + AudioOutput, + AudioContent, +) +from app.services.llm.chain.chain import ( + ChainBlock, + ChainContext, + LLMChain, + result_to_query, +) +from app.services.llm.chain.types import BlockResult + + +@pytest.fixture +def context(): + return ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url="https://example.com/callback", + total_blocks=3, + intermediate_callback_flags=[True, True, False], + ) + + +@pytest.fixture +def 
text_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-1", + conversation_id=None, + model="gpt-4", + provider="openai", + output=TextOutput(content=ResponseTextContent(value="Hello world")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + +@pytest.fixture +def audio_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-2", + conversation_id=None, + model="gemini", + provider="google", + output=AudioOutput( + content=AudioContent( + format="base64", + value="audio-data-base64", + mime_type="audio/wav", + ) + ), + ), + usage=Usage(input_tokens=5, output_tokens=15, total_tokens=20), + provider_raw_response=None, + ) + + +def make_config(): + return LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ) + ) + ) + + +class TestResultToQuery: + def test_text_output_to_query(self, text_response): + result = BlockResult(response=text_response, usage=text_response.usage) + + query = result_to_query(result) + + assert isinstance(query.input, TextInput) + assert query.input.content.value == "Hello world" + + def test_audio_output_to_query(self, audio_response): + result = BlockResult(response=audio_response, usage=audio_response.usage) + + query = result_to_query(result) + + assert isinstance(query.input, AudioInput) + assert query.input.content.value == "audio-data-base64" + + def test_unsupported_output_type_raises(self): + mock_response = MagicMock() + mock_response.response.output.type = "unknown" + mock_response.response.output.__class__ = type("Unknown", (), {}) + result = BlockResult(response=mock_response, usage=MagicMock()) + + with pytest.raises(ValueError, match="Cannot chain output type"): + result_to_query(result) + + +class TestChainBlock: + def test_execute_single_block(self, context, text_response): + query = 
QueryParams(input="test input") + config = make_config() + block = ChainBlock(config=config, index=0, context=context) + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = block.execute(query) + + assert result.success + mock_execute.assert_called_once() + + def test_execute_returns_failure(self, context): + query = QueryParams(input="test input") + config = make_config() + block = ChainBlock(config=config, index=0, context=context) + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult(error="Provider error") + + result = block.execute(query) + + assert not result.success + assert result.error == "Provider error" + mock_execute.assert_called_once() + + +class TestLLMChain: + def test_execute_empty_chain(self, context): + chain = LLMChain([], context) + query = QueryParams(input="test") + + result = chain.execute(query) + + assert not result.success + assert result.error == "Chain has no blocks" + + def test_execute_single_block_chain(self, context, text_response): + config = make_config() + block = ChainBlock(config=config, index=0, context=context) + chain = LLMChain([block], context) + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = chain.execute(QueryParams(input="hello")) + + assert result.success + mock_execute.assert_called_once() + + def test_execute_multi_block_chain(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(3)] + chain = LLMChain(blocks, context) + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage 
+ ) + + result = chain.execute(QueryParams(input="hello")) + + assert result.success + assert mock_execute.call_count == 3 + + def test_execute_stops_on_failure(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(3)] + chain = LLMChain(blocks, context) + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult(error="Provider error") + + result = chain.execute(QueryParams(input="hello")) + + assert not result.success + assert result.error == "Provider error" + mock_execute.assert_called_once() + + def test_execute_calls_on_block_completed(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(2)] + chain = LLMChain(blocks, context) + callback = MagicMock() + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + chain.execute(QueryParams(input="hello"), on_block_completed=callback) + + assert callback.call_count == 2 diff --git a/backend/app/tests/services/llm/test_chain_executor.py b/backend/app/tests/services/llm/test_chain_executor.py new file mode 100644 index 000000000..6564ebafb --- /dev/null +++ b/backend/app/tests/services/llm/test_chain_executor.py @@ -0,0 +1,381 @@ +from unittest.mock import patch, MagicMock +from uuid import uuid4 + +import pytest + +from app.models.llm.request import ( + LLMChainRequest, + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + QueryParams, + ChainStatus, +) +from app.models.llm.request import ChainBlock as ChainBlockModel +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + Usage, + TextOutput, + TextContent, +) +from app.models import JobStatus +from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain +from 
app.services.llm.chain.executor import ChainExecutor +from app.services.llm.chain.types import BlockResult + + +@pytest.fixture +def context(): + return ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url="https://example.com/callback", + total_blocks=1, + ) + + +@pytest.fixture +def request_obj(): + return LLMChainRequest( + query=QueryParams(input="hello"), + blocks=[ + ChainBlockModel( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ) + ) + ) + ) + ], + callback_url="https://example.com/callback", + ) + + +@pytest.fixture +def text_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-1", + conversation_id=None, + model="gpt-4", + provider="openai", + output=TextOutput(content=TextContent(value="Response text")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + +@pytest.fixture +def success_result(text_response): + return BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + ) + + +@pytest.fixture +def failure_result(): + return BlockResult(error="Provider failed") + + +class TestChainExecutor: + def _make_executor(self, context, request_obj, chain_result): + mock_chain = MagicMock(spec=LLMChain) + mock_chain.execute.return_value = chain_result + return ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + def test_run_success_with_callback(self, context, request_obj, success_result): + executor = self._make_executor(context, request_obj, success_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + ): + mock_session.return_value.__enter__.return_value 
= MagicMock() + + result = executor.run() + + assert result["success"] is True + mock_callback.assert_called_once() + # Verify chain status updated to COMPLETED + completed_call = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.COMPLETED + ] + assert len(completed_call) == 1 + + def test_run_success_without_callback(self, context, request_obj, success_result): + request_obj.callback_url = None + context.callback_url = None + executor = self._make_executor(context, request_obj, success_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is True + mock_callback.assert_not_called() + + def test_run_failure_updates_status(self, context, request_obj, failure_result): + executor = self._make_executor(context, request_obj, failure_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is False + assert result["error"] == "Provider failed" + # Verify chain status updated to FAILED + failed_call = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.FAILED + ] + assert len(failed_call) == 1 + + def test_run_failure_sends_callback(self, context, request_obj, failure_result): + executor = self._make_executor(context, request_obj, failure_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as 
mock_callback, + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + mock_callback.assert_called_once() + + def test_run_unexpected_exception_handled(self, context, request_obj): + mock_chain = MagicMock(spec=LLMChain) + mock_chain.execute.side_effect = RuntimeError("Something broke") + executor = ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is False + assert "Unexpected error occurred" in result["error"] + + def test_setup_updates_job_and_chain_status( + self, context, request_obj, success_result + ): + executor = self._make_executor(context, request_obj, success_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + patch("app.services.llm.chain.executor.JobCrud") as mock_job_crud, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + executor.run() + + # _setup should set chain to RUNNING + running_calls = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.RUNNING + ] + assert len(running_calls) == 1 + + +class TestOnBlockCompleted: + def _make_executor(self, context, request_obj): + mock_chain = MagicMock(spec=LLMChain) + return ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + def test_aggregates_usage(self, context, request_obj): + executor = self._make_executor(context, request_obj) + usage = Usage(input_tokens=10, 
output_tokens=20, total_tokens=30) + result = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage, error=None + ) + + with patch("app.services.llm.chain.executor.Session"): + executor._on_block_completed(0, result) + + assert context.aggregated_usage.input_tokens == 10 + assert context.aggregated_usage.output_tokens == 20 + assert context.aggregated_usage.total_tokens == 30 + + def test_aggregates_usage_across_blocks(self, context, request_obj): + executor = self._make_executor(context, request_obj) + result1 = BlockResult( + response=MagicMock(), + llm_call_id=uuid4(), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + error=None, + ) + result2 = BlockResult( + response=MagicMock(), + llm_call_id=uuid4(), + usage=Usage(input_tokens=5, output_tokens=15, total_tokens=20), + error=None, + ) + + with patch("app.services.llm.chain.executor.Session"): + executor._on_block_completed(0, result1) + executor._on_block_completed(1, result2) + + assert context.aggregated_usage.input_tokens == 15 + assert context.aggregated_usage.total_tokens == 50 + + def test_updates_db_on_success(self, context, request_obj): + executor = self._make_executor(context, request_obj) + llm_call_id = uuid4() + result = BlockResult( + response=MagicMock(), llm_call_id=llm_call_id, usage=MagicMock(), error=None + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch( + "app.services.llm.chain.executor.update_llm_chain_block_completed" + ) as mock_update, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(0, result) + + mock_update.assert_called_once_with( + mock_session.return_value.__enter__.return_value, + chain_id=context.chain_id, + llm_call_id=llm_call_id, + ) + + def test_skips_db_update_on_error(self, context, request_obj): + executor = self._make_executor(context, request_obj) + result = BlockResult(error="Block failed", usage=MagicMock()) + + with patch( + 
"app.services.llm.chain.executor.update_llm_chain_block_completed" + ) as mock_update: + executor._on_block_completed(0, result) + mock_update.assert_not_called() + + def test_sends_intermediate_callback(self, context, request_obj, text_response): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(0, result) + + mock_callback.assert_called_once() + + def test_skips_intermediate_callback_for_last_block( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(2, result) + + mock_callback.assert_not_called() + + def test_skips_intermediate_callback_when_flag_false( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [False, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + 
error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(0, result) + + mock_callback.assert_not_called() + + def test_intermediate_callback_exception_is_swallowed( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch( + "app.services.llm.chain.executor.send_callback", + side_effect=Exception("Connection error"), + ), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + # Should not raise + executor._on_block_completed(0, result) diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index 60456e00b..cc67a7d6e 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -23,11 +23,14 @@ # KaapiLLMParams, KaapiCompletionConfig, ) -from app.models.llm.request import ConfigBlob, LLMCallConfig +from app.models.llm.request import ConfigBlob, LLMCallConfig, LLMChainRequest +from app.models.llm.request import ChainBlock as ChainBlockModel from app.services.llm.jobs import ( start_job, + start_chain_job, handle_job_error, execute_job, + execute_chain_job, resolve_config_blob, ) from app.tests.utils.utils import get_project @@ -367,7 +370,7 @@ def test_exception_during_execution( result = self._execute_job(job_for_execution, db, request_data) assert not 
result["success"] - assert "Unexpected error during LLM execution" in result["error"] + assert "Unexpected error occurred" in result["error"] def test_exception_during_provider_retrieval( self, db, job_env, job_for_execution, request_data @@ -1108,16 +1111,21 @@ def test_execute_job_fetches_validator_configs_from_blob_refs( result = self._execute_job(job_for_execution, db, request_data) assert result["success"] - mock_fetch_configs.assert_called_once() - _, kwargs = mock_fetch_configs.call_args - input_validator_configs = kwargs["input_validator_configs"] - output_validator_configs = kwargs["output_validator_configs"] - assert [v.validator_config_id for v in input_validator_configs] == [ - UUID(VALIDATOR_CONFIG_ID_1) - ] - assert [v.validator_config_id for v in output_validator_configs] == [ - UUID(VALIDATOR_CONFIG_ID_2) - ] + assert mock_fetch_configs.call_count == 2 + + # First call: input guardrails + _, input_kwargs = mock_fetch_configs.call_args_list[0] + assert [ + v.validator_config_id for v in input_kwargs["input_validator_configs"] + ] == [UUID(VALIDATOR_CONFIG_ID_1)] + assert input_kwargs["output_validator_configs"] is None + + # Second call: output guardrails + _, output_kwargs = mock_fetch_configs.call_args_list[1] + assert output_kwargs["input_validator_configs"] is None + assert [ + v.validator_config_id for v in output_kwargs["output_validator_configs"] + ] == [UUID(VALIDATOR_CONFIG_ID_2)] def test_execute_job_continues_when_no_validator_configs_resolved( self, db, job_env, job_for_execution @@ -1156,6 +1164,207 @@ def test_execute_job_continues_when_no_validator_configs_resolved( mock_guardrails.assert_not_called() +class TestStartChainJob: + """Test cases for the start_chain_job function.""" + + @pytest.fixture + def chain_request(self): + return LLMChainRequest( + query=QueryParams(input="Test query"), + blocks=[ + ChainBlockModel( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + 
type="text", + params={"model": "gpt-4"}, + ) + ) + ) + ) + ], + ) + + def test_start_chain_job_success(self, db: Session, chain_request): + project = get_project(db) + + with ( + patch("app.services.llm.jobs.start_high_priority_job") as mock_schedule, + patch("app.services.llm.jobs.JobCrud") as mock_job_crud_class, + ): + mock_schedule.return_value = "fake-task-id" + mock_job = MagicMock() + mock_job.id = uuid4() + mock_job.job_type = JobType.LLM_CHAIN + mock_job.status = JobStatus.PENDING + mock_job_crud_class.return_value.create.return_value = mock_job + + job_id = start_chain_job( + db, chain_request, project.id, project.organization_id + ) + + assert job_id == mock_job.id + mock_schedule.assert_called_once() + _, kwargs = mock_schedule.call_args + assert kwargs["function_path"] == "app.services.llm.jobs.execute_chain_job" + + def test_start_chain_job_celery_failure(self, db: Session, chain_request): + project = get_project(db) + + with ( + patch("app.services.llm.jobs.start_high_priority_job") as mock_schedule, + patch("app.services.llm.jobs.JobCrud") as mock_job_crud_class, + ): + mock_schedule.side_effect = Exception("Celery connection failed") + mock_job = MagicMock() + mock_job.id = uuid4() + mock_job_crud_class.return_value.create.return_value = mock_job + + with pytest.raises(HTTPException) as exc_info: + start_chain_job(db, chain_request, project.id, project.organization_id) + + assert exc_info.value.status_code == 500 + assert "Internal server error while executing LLM chain job" in str( + exc_info.value.detail + ) + + +class TestExecuteChainJob: + """Test suite for execute_chain_job.""" + + @pytest.fixture + def chain_request_data(self): + return { + "query": {"input": "Test query"}, + "blocks": [ + { + "config": { + "blob": { + "completion": { + "provider": "openai-native", + "type": "text", + "params": {"model": "gpt-4"}, + } + } + }, + } + ], + } + + @pytest.fixture + def mock_llm_response(self): + return LLMCallResponse( + response=LLMResponse( + 
provider_response_id="resp-123", + conversation_id=None, + model="gpt-4", + provider="openai", + output=TextOutput(content=TextContent(value="Test response")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + def _execute_chain_job(self, request_data): + return execute_chain_job( + request_data=request_data, + project_id=1, + organization_id=1, + job_id=str(uuid4()), + task_id="task-123", + task_instance=None, + ) + + def test_success_flow(self, chain_request_data, mock_llm_response): + from app.services.llm.chain.types import BlockResult + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch("app.services.llm.chain.executor.Session") as mock_executor_session, + patch("app.services.llm.chain.executor.send_callback"), + patch("app.services.llm.chain.executor.update_llm_chain_status"), + patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute_llm, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + mock_executor_session.return_value.__enter__.return_value = MagicMock() + mock_executor_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = uuid4() + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + + mock_execute_llm.return_value = BlockResult( + response=mock_llm_response, + llm_call_id=uuid4(), + usage=mock_llm_response.usage, + ) + + result = self._execute_chain_job(chain_request_data) + + assert result["success"] is True + + def test_exception_during_chain_creation(self, chain_request_data): + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch( + 
"app.services.llm.jobs.create_llm_chain", + side_effect=Exception("DB error"), + ), + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + mock_handle_error.return_value = { + "success": False, + "error": "Unexpected error occurred", + } + + result = self._execute_chain_job(chain_request_data) + + assert result["success"] is False + + def test_chain_status_updated_to_failed_on_error(self, chain_request_data): + chain_id = uuid4() + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch( + "app.services.llm.jobs.update_llm_chain_status" + ) as mock_update_status, + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + patch( + "app.services.llm.chain.chain.LLMChain", + side_effect=Exception("Chain init error"), + ), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = chain_id + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + mock_handle_error.return_value = { + "success": False, + "error": "Unexpected error occurred", + } + + result = self._execute_chain_job(chain_request_data) + + mock_update_status.assert_called_once() + _, kwargs = mock_update_status.call_args + assert kwargs["chain_id"] == chain_id + assert kwargs["status"].value == "failed" + + class TestResolveConfigBlob: """Test suite for resolve_config_blob function."""