3636 LLMCallConfig ,
3737 PDFInput ,
3838 QueryParams ,
39+ TextContent ,
3940 TextInput ,
4041)
41- from app .models .llm .response import TextOutput
42+ from app .models .llm .response import LLMCallResponse , LLMResponse , TextOutput , Usage
4243from app .services .llm .chain .types import BlockResult
4344from app .services .llm .guardrails import (
4445 list_validators_config ,
@@ -264,18 +265,24 @@ def apply_input_guardrails(
264265 job_id : UUID ,
265266 project_id : int ,
266267 organization_id : int ,
267- ) -> tuple [QueryParams , str | None ]:
268- """Apply input guardrails from a config_blob. Shared with llm-call and llm-chain."""
268+ ) -> tuple [QueryParams , str | None , str | None ]:
269+ """Apply input guardrails from a config_blob. Shared with llm-call and llm-chain.
270+
271+ Returns (query, error, guardrail_direct_response) where:
272+ - error is set when guardrails hard-block the request
273+ - guardrail_direct_response is set when rephrase_needed=True and the safe_text
274+ should be returned directly to the user without hitting the LLM
275+ """
269276 if not config_blob or not config_blob .input_guardrails :
270- return query , None
277+ return query , None , None
271278
272279 if not isinstance (query .input , TextInput ):
273280 logger .info (
274281 f"[apply_input_guardrails] Skipping for non-text input. "
275282 f"job_id={ job_id } , "
276283 f"input_type={ getattr (query .input , 'type' , type (query .input ).__name__ )} "
277284 )
278- return query , None
285+ return query , None , None
279286
280287 input_guardrails , _ = list_validators_config (
281288 organization_id = organization_id ,
@@ -285,7 +292,7 @@ def apply_input_guardrails(
285292 )
286293
287294 if not input_guardrails :
288- return query , None
295+ return query , None , None
289296
290297 safe = run_guardrails_validation (
291298 query .input .content .value ,
@@ -304,13 +311,19 @@ def apply_input_guardrails(
304311 logger .info (
305312 f"[apply_input_guardrails] Guardrails bypassed (service unavailable) | job_id={ job_id } "
306313 )
307- return query , None
314+ return query , None , None
308315
309316 if safe ["success" ]:
310- query .input .content .value = safe ["data" ]["safe_text" ]
311- return query , None
317+ safe_text = safe ["data" ]["safe_text" ]
318+ if safe ["data" ].get ("rephrase_needed" ):
319+ logger .info (
320+ f"[apply_input_guardrails] rephrase_needed=True, returning safe_text directly | job_id={ job_id } "
321+ )
322+ return query , None , safe_text
323+ query .input .content .value = safe_text
324+ return query , None , None
312325
313- return query , safe ["error" ]
326+ return query , safe ["error" ], None
314327
315328
316329def apply_output_guardrails (
@@ -418,13 +431,31 @@ def execute_llm_call(
418431
419432 with tracer .start_as_current_span ("llm.guardrails.input" ) as guard_span :
420433 guard_span .set_attribute ("llm.job_id" , str (job_id ))
421- query , input_error = apply_input_guardrails (
434+ query , input_error , guardrail_direct_response = apply_input_guardrails (
422435 config_blob = config_blob ,
423436 query = query ,
424437 job_id = job_id ,
425438 project_id = project_id ,
426439 organization_id = organization_id ,
427440 )
441+ if guardrail_direct_response is not None :
442+ guardrail_usage = Usage (
443+ input_tokens = 0 ,
444+ output_tokens = 0 ,
445+ total_tokens = 0 ,
446+ )
447+ llm_response = LLMCallResponse (
448+ response = LLMResponse (
449+ provider_response_id = str (job_id ),
450+ provider = "guardrail" ,
451+ model = "guardrail" ,
452+ output = TextOutput (
453+ content = TextContent (value = guardrail_direct_response )
454+ ),
455+ ),
456+ usage = guardrail_usage ,
457+ )
458+ return BlockResult (response = llm_response , usage = guardrail_usage )
428459 if input_error :
429460 guard_span .set_status (
430461 trace .Status (trace .StatusCode .ERROR , input_error )
0 commit comments