Skip to content

Commit c2b39fc

Browse files
committed
guardrails: adding rephrase support
1 parent 7135ec2 commit c2b39fc

1 file changed

Lines changed: 42 additions & 11 deletions

File tree

backend/app/services/llm/jobs.py

Lines changed: 42 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -36,9 +36,10 @@
3636
LLMCallConfig,
3737
PDFInput,
3838
QueryParams,
39+
TextContent,
3940
TextInput,
4041
)
41-
from app.models.llm.response import TextOutput
42+
from app.models.llm.response import LLMCallResponse, LLMResponse, TextOutput, Usage
4243
from app.services.llm.chain.types import BlockResult
4344
from app.services.llm.guardrails import (
4445
list_validators_config,
@@ -264,18 +265,24 @@ def apply_input_guardrails(
264265
job_id: UUID,
265266
project_id: int,
266267
organization_id: int,
267-
) -> tuple[QueryParams, str | None]:
268-
"""Apply input guardrails from a config_blob. Shared with llm-call and llm-chain."""
268+
) -> tuple[QueryParams, str | None, str | None]:
269+
"""Apply input guardrails from a config_blob. Shared with llm-call and llm-chain.
270+
271+
Returns (query, error, guardrail_direct_response) where:
272+
- error is set when guardrails hard-block the request
273+
- guardrail_direct_response is set when rephrase_needed=True and the safe_text
274+
should be returned directly to the user without hitting the LLM
275+
"""
269276
if not config_blob or not config_blob.input_guardrails:
270-
return query, None
277+
return query, None, None
271278

272279
if not isinstance(query.input, TextInput):
273280
logger.info(
274281
f"[apply_input_guardrails] Skipping for non-text input. "
275282
f"job_id={job_id}, "
276283
f"input_type={getattr(query.input, 'type', type(query.input).__name__)}"
277284
)
278-
return query, None
285+
return query, None, None
279286

280287
input_guardrails, _ = list_validators_config(
281288
organization_id=organization_id,
@@ -285,7 +292,7 @@ def apply_input_guardrails(
285292
)
286293

287294
if not input_guardrails:
288-
return query, None
295+
return query, None, None
289296

290297
safe = run_guardrails_validation(
291298
query.input.content.value,
@@ -304,13 +311,19 @@ def apply_input_guardrails(
304311
logger.info(
305312
f"[apply_input_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}"
306313
)
307-
return query, None
314+
return query, None, None
308315

309316
if safe["success"]:
310-
query.input.content.value = safe["data"]["safe_text"]
311-
return query, None
317+
safe_text = safe["data"]["safe_text"]
318+
if safe["data"].get("rephrase_needed"):
319+
logger.info(
320+
f"[apply_input_guardrails] rephrase_needed=True, returning safe_text directly | job_id={job_id}"
321+
)
322+
return query, None, safe_text
323+
query.input.content.value = safe_text
324+
return query, None, None
312325

313-
return query, safe["error"]
326+
return query, safe["error"], None
314327

315328

316329
def apply_output_guardrails(
@@ -418,13 +431,31 @@ def execute_llm_call(
418431

419432
with tracer.start_as_current_span("llm.guardrails.input") as guard_span:
420433
guard_span.set_attribute("llm.job_id", str(job_id))
421-
query, input_error = apply_input_guardrails(
434+
query, input_error, guardrail_direct_response = apply_input_guardrails(
422435
config_blob=config_blob,
423436
query=query,
424437
job_id=job_id,
425438
project_id=project_id,
426439
organization_id=organization_id,
427440
)
441+
if guardrail_direct_response is not None:
442+
guardrail_usage = Usage(
443+
input_tokens=0,
444+
output_tokens=0,
445+
total_tokens=0,
446+
)
447+
llm_response = LLMCallResponse(
448+
response=LLMResponse(
449+
provider_response_id=str(job_id),
450+
provider="guardrail",
451+
model="guardrail",
452+
output=TextOutput(
453+
content=TextContent(value=guardrail_direct_response)
454+
),
455+
),
456+
usage=guardrail_usage,
457+
)
458+
return BlockResult(response=llm_response, usage=guardrail_usage)
428459
if input_error:
429460
guard_span.set_status(
430461
trace.Status(trace.StatusCode.ERROR, input_error)

0 commit comments

Comments
 (0)