57 changes: 46 additions & 11 deletions backend/app/services/llm/jobs.py
@@ -36,9 +36,10 @@
LLMCallConfig,
PDFInput,
QueryParams,
TextContent,
TextInput,
)
from app.models.llm.response import TextOutput
from app.models.llm.response import LLMCallResponse, LLMResponse, TextOutput, Usage
from app.services.llm.chain.types import BlockResult
from app.services.llm.guardrails import (
list_validators_config,
@@ -264,18 +265,24 @@ def apply_input_guardrails(
job_id: UUID,
project_id: int,
organization_id: int,
) -> tuple[QueryParams, str | None]:
"""Apply input guardrails from a config_blob. Shared with llm-call and llm-chain."""
) -> tuple[QueryParams, str | None, str | None]:
"""Apply input guardrails from a config_blob. Shared with llm-call and llm-chain.

Returns (query, error, guardrail_direct_response) where:
- error is set when guardrails hard-block the request
- guardrail_direct_response is set when rephrase_needed=True and the safe_text
should be returned directly to the user without hitting the LLM
"""
if not config_blob or not config_blob.input_guardrails:
return query, None
return query, None, None

if not isinstance(query.input, TextInput):
logger.info(
f"[apply_input_guardrails] Skipping for non-text input. "
f"job_id={job_id}, "
f"input_type={getattr(query.input, 'type', type(query.input).__name__)}"
)
return query, None
return query, None, None

input_guardrails, _ = list_validators_config(
organization_id=organization_id,
@@ -285,7 +292,7 @@ def apply_input_guardrails(
)

if not input_guardrails:
return query, None
return query, None, None

safe = run_guardrails_validation(
query.input.content.value,
@@ -304,13 +311,19 @@
logger.info(
f"[apply_input_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}"
)
return query, None
return query, None, None

if safe["success"]:
query.input.content.value = safe["data"]["safe_text"]
return query, None
safe_text = safe["data"]["safe_text"]
if safe["data"].get("rephrase_needed"):
logger.info(
f"[apply_input_guardrails] rephrase_needed=True, returning safe_text directly | job_id={job_id}"
)
return query, None, safe_text
query.input.content.value = safe_text
return query, None, None

return query, safe["error"]
return query, safe["error"], None


def apply_output_guardrails(
@@ -418,13 +431,35 @@ def execute_llm_call(

with tracer.start_as_current_span("llm.guardrails.input") as guard_span:
guard_span.set_attribute("llm.job_id", str(job_id))
query, input_error = apply_input_guardrails(
query, input_error, guardrail_direct_response = apply_input_guardrails(
config_blob=config_blob,
query=query,
job_id=job_id,
project_id=project_id,
organization_id=organization_id,
)
if guardrail_direct_response is not None:
guardrail_usage = Usage(
input_tokens=0,
output_tokens=0,
total_tokens=0,
)
llm_response = LLMCallResponse(
response=LLMResponse(
provider_response_id=str(job_id),
provider=str(config_blob.completion.provider),
model=str(config_blob.completion.params.get("model") or ""),
output=TextOutput(
content=TextContent(value=guardrail_direct_response)
),
),
usage=guardrail_usage,
)
return BlockResult(
response=llm_response,
usage=guardrail_usage,
metadata=request_metadata,
)
if input_error:
guard_span.set_status(
trace.Status(trace.StatusCode.ERROR, input_error)
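For reviewers, a minimal sketch of how a caller is expected to consume the new three-tuple contract of `apply_input_guardrails`, mirroring the early-return path added to `execute_llm_call` above. This is not code from the PR: `short_circuit_on_guardrails` is an illustrative helper name, and it assumes it lives in `backend/app/services/llm/jobs.py` next to the imports shown in the diff (including the newly imported `TextContent`).

```python
# Sketch only, not part of the diff. Assumes the imports already present in
# jobs.py (TextContent from the request models, plus the response models below).
from app.models.llm.response import LLMCallResponse, LLMResponse, TextOutput, Usage


def short_circuit_on_guardrails(config_blob, query, job_id, project_id, organization_id):
    query, input_error, guardrail_direct_response = apply_input_guardrails(
        config_blob=config_blob,
        query=query,
        job_id=job_id,
        project_id=project_id,
        organization_id=organization_id,
    )

    if guardrail_direct_response is not None:
        # rephrase_needed=True: return the safe_text directly and never call the provider.
        zero_usage = Usage(input_tokens=0, output_tokens=0, total_tokens=0)
        return LLMCallResponse(
            response=LLMResponse(
                provider_response_id=str(job_id),
                provider=str(config_blob.completion.provider),
                model=str(config_blob.completion.params.get("model") or ""),
                output=TextOutput(content=TextContent(value=guardrail_direct_response)),
            ),
            usage=zero_usage,
        )

    if input_error:
        # Hard block: surface the guardrail error instead of calling the provider.
        raise RuntimeError(input_error)

    # No guardrail intervention (or input sanitized in place): proceed to the LLM as before.
    return None
```

The key design point is that the rephrase case is checked before the hard-block case, so a guardrail that can produce a safe rewrite answers the user directly with zero token usage instead of failing the job.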
7 changes: 4 additions & 3 deletions backend/app/tests/services/llm/test_jobs.py
@@ -1074,9 +1074,10 @@ def test_guardrails_rephrase_needed_allows_job_with_sanitized_input(
result = self._execute_job(job_for_execution, db, request_data)

assert result["success"] is True
env["provider"].execute.assert_called_once()
provider_query = env["provider"].execute.call_args[0][1]
assert provider_query.input.content.value == "Rephrased text"
env["provider"].execute.assert_not_called()
assert (
result["data"]["response"]["output"]["content"]["value"] == "Rephrased text"
)

def test_execute_job_fetches_validator_configs_from_blob_refs(
self, db, job_env, job_for_execution
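The updated test asserts that the provider is never invoked and that the safe text comes back as the response body. A downstream consumer that needs to distinguish this guardrail-generated response from a normal completion could, under the construction above, key off the zeroed usage and the reused job id; the PR adds no dedicated flag, so this heuristic is an assumption, not part of the diff.

```python
# Sketch only: detecting the guardrail direct-response case downstream.
# Assumes `result` is the serialized job result used in the test above; the
# exact location of the usage payload is an assumption.
def is_guardrail_direct_response(result: dict, job_id) -> bool:
    response = result["data"]["response"]
    usage = result["data"].get("usage", {})
    return (
        usage.get("total_tokens") == 0
        and response.get("provider_response_id") == str(job_id)
    )
```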