diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 436b26a58..60dbdb49e 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -36,9 +36,10 @@ LLMCallConfig, PDFInput, QueryParams, + TextContent, TextInput, ) -from app.models.llm.response import TextOutput +from app.models.llm.response import LLMCallResponse, LLMResponse, TextOutput, Usage from app.services.llm.chain.types import BlockResult from app.services.llm.guardrails import ( list_validators_config, @@ -264,10 +265,16 @@ def apply_input_guardrails( job_id: UUID, project_id: int, organization_id: int, -) -> tuple[QueryParams, str | None]: - """Apply input guardrails from a config_blob. Shared with llm-call and llm-chain.""" +) -> tuple[QueryParams, str | None, str | None]: + """Apply input guardrails from a config_blob. Shared with llm-call and llm-chain. + + Returns (query, error, guardrail_direct_response) where: + - error is set when guardrails hard-block the request + - guardrail_direct_response is set when rephrase_needed=True and the safe_text + should be returned directly to the user without hitting the LLM + """ if not config_blob or not config_blob.input_guardrails: - return query, None + return query, None, None if not isinstance(query.input, TextInput): logger.info( @@ -275,7 +282,7 @@ def apply_input_guardrails( f"job_id={job_id}, " f"input_type={getattr(query.input, 'type', type(query.input).__name__)}" ) - return query, None + return query, None, None input_guardrails, _ = list_validators_config( organization_id=organization_id, @@ -285,7 +292,7 @@ def apply_input_guardrails( ) if not input_guardrails: - return query, None + return query, None, None safe = run_guardrails_validation( query.input.content.value, @@ -304,13 +311,19 @@ def apply_input_guardrails( logger.info( f"[apply_input_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" ) - return query, None + return query, None, None if safe["success"]: - query.input.content.value = safe["data"]["safe_text"] - return query, None + safe_text = safe["data"]["safe_text"] + if safe["data"].get("rephrase_needed"): + logger.info( + f"[apply_input_guardrails] rephrase_needed=True, returning safe_text directly | job_id={job_id}" + ) + return query, None, safe_text + query.input.content.value = safe_text + return query, None, None - return query, safe["error"] + return query, safe["error"], None def apply_output_guardrails( @@ -418,13 +431,35 @@ def execute_llm_call( with tracer.start_as_current_span("llm.guardrails.input") as guard_span: guard_span.set_attribute("llm.job_id", str(job_id)) - query, input_error = apply_input_guardrails( + query, input_error, guardrail_direct_response = apply_input_guardrails( config_blob=config_blob, query=query, job_id=job_id, project_id=project_id, organization_id=organization_id, ) + if guardrail_direct_response is not None: + guardrail_usage = Usage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + ) + llm_response = LLMCallResponse( + response=LLMResponse( + provider_response_id=str(job_id), + provider=str(config_blob.completion.provider), + model=str(config_blob.completion.params.get("model") or ""), + output=TextOutput( + content=TextContent(value=guardrail_direct_response) + ), + ), + usage=guardrail_usage, + ) + return BlockResult( + response=llm_response, + usage=guardrail_usage, + metadata=request_metadata, + ) if input_error: guard_span.set_status( trace.Status(trace.StatusCode.ERROR, input_error) diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index c1b6acf25..e838796ad 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -1074,9 +1074,10 @@ def test_guardrails_rephrase_needed_allows_job_with_sanitized_input( result = self._execute_job(job_for_execution, db, request_data) assert result["success"] is True - env["provider"].execute.assert_called_once() - provider_query = env["provider"].execute.call_args[0][1] - assert provider_query.input.content.value == "Rephrased text" + env["provider"].execute.assert_not_called() + assert ( + result["data"]["response"]["output"]["content"]["value"] == "Rephrased text" + ) def test_execute_job_fetches_validator_configs_from_blob_refs( self, db, job_env, job_for_execution