diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py
index c395769..dd195d8 100644
--- a/backend/app/api/routes/guardrails.py
+++ b/backend/app/api/routes/guardrails.py
@@ -7,7 +7,12 @@
 from sqlmodel import Session
 
 from app.api.deps import AuthDep, SessionDep
-from app.core.constants import BAN_LIST, REPHRASE_ON_FAIL_PREFIX
+from app.core.constants import (
+    BAN_LIST,
+    LLM_CRITIC_ERROR_MESSAGE,
+    LLM_CRITIC_REPHRASE_MESSAGE,
+    REPHRASE_ON_FAIL_PREFIX,
+)
 from app.core.enum import ValidatorType
 from app.core.guardrail_controller import build_guard, get_validator_config_models
 from app.core.exception_handlers import _safe_error_message
@@ -173,8 +178,9 @@ def _finalize(
         guard, request_log_id, validator_log_crud, payload, suppress_pass_logs
     )
 
-    rephrase_needed = validated_output is not None and validated_output.startswith(
-        REPHRASE_ON_FAIL_PREFIX
+    rephrase_needed = validated_output is not None and (
+        validated_output.startswith(REPHRASE_ON_FAIL_PREFIX)
+        or validated_output == LLM_CRITIC_REPHRASE_MESSAGE
     )
 
     response_model = GuardrailResponse(
@@ -244,7 +250,11 @@ def _extract_error_from_guard(guard: Guard, data: str) -> str | None:
     for log in logs:
         log_result = log.validation_result
         if isinstance(log_result, FailResult) and log_result.error_message:
-            if log.validator_name == ValidatorType.LLMCritic.name:
+            if log.validator_name in (
+                ValidatorType.LLMCritic.name,
+                ValidatorType.LLMCritic.value,
+                "LLM_Critic",
+            ):
                 return _normalize_llm_critic_error(log_result.error_message)
             return _redact_input(log_result.error_message, data)
     return None
@@ -307,12 +317,9 @@ def add_validator_logs(
 
 
 def _normalize_llm_critic_error(message: str) -> str:
-    if "failed the following metrics" in message:
-        return "The response did not meet the required quality criteria."
-    if "missing or has invalid evaluations" in message:
-        return (
-            "The LLM critic could not evaluate one or more metrics. "
-            "The critic model returned an incomplete or malformed response. "
-            "Please retry."
-        )
+    if (
+        "failed the following metrics" in message
+        or "missing or has invalid evaluations" in message
+    ):
+        return LLM_CRITIC_ERROR_MESSAGE
     return message
diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py
index 2523746..1b83095 100644
--- a/backend/app/core/constants.py
+++ b/backend/app/core/constants.py
@@ -6,6 +6,10 @@
 SCORE = "score"
 
 REPHRASE_ON_FAIL_PREFIX = "Please rephrase the query without unsafe content."
+LLM_CRITIC_ERROR_MESSAGE = "The query did not meet the required quality criteria."
+LLM_CRITIC_REPHRASE_MESSAGE = (
+    f"{LLM_CRITIC_ERROR_MESSAGE} Please rephrase without unsafe content."
+)
 
 VALIDATOR_CONFIG_SYSTEM_FIELDS = {
     "organization_id",
diff --git a/backend/app/core/on_fail_actions.py b/backend/app/core/on_fail_actions.py
index eb1712c..152a319 100644
--- a/backend/app/core/on_fail_actions.py
+++ b/backend/app/core/on_fail_actions.py
@@ -3,6 +3,10 @@
 from app.core.constants import REPHRASE_ON_FAIL_PREFIX
 
 
-def rephrase_query_on_fail(value: str, fail_result: FailResult):
-    error_message = (fail_result.error_message or "").replace(value, "[REDACTED]")
-    return f"{REPHRASE_ON_FAIL_PREFIX} {error_message}"
+def rephrase_query_on_fail(
+    value: str, fail_result: FailResult, include_reason: bool = True
+) -> str:
+    if include_reason:
+        error_message = (fail_result.error_message or "").replace(value, "[REDACTED]")
+        return f"{REPHRASE_ON_FAIL_PREFIX} {error_message}"
+    return REPHRASE_ON_FAIL_PREFIX
diff --git a/backend/app/core/validators/config/base_validator_config.py b/backend/app/core/validators/config/base_validator_config.py
index d52e93f..1fda762 100644
--- a/backend/app/core/validators/config/base_validator_config.py
+++ b/backend/app/core/validators/config/base_validator_config.py
@@ -5,7 +5,7 @@
 from pydantic import ConfigDict, PrivateAttr
 from sqlmodel import SQLModel
 
-from app.core.enum import GuardrailOnFail
+from app.core.enum import GuardrailOnFail, ValidatorType
 from app.core.on_fail_actions import rephrase_query_on_fail
 
 
@@ -30,7 +30,14 @@ def resolve_on_fail(self):
         elif self.on_fail == GuardrailOnFail.Exception:
             return OnFailAction.EXCEPTION
         elif self.on_fail == GuardrailOnFail.Rephrase:
-            return rephrase_query_on_fail
+            include_reason = True
+            if self.type == ValidatorType.LLMCritic.value:
+                include_reason = False  # For LLM critic, we don't want to include the reason in the rephrase to avoid confusion
+
+            return lambda value, fail_result: rephrase_query_on_fail(
+                value, fail_result, include_reason=include_reason
+            )
+
         raise ValueError(
             f"Invalid on_fail value: {self.on_fail}. "
             "Expected one of: exception, fix, rephrase."
diff --git a/backend/app/core/validators/config/llm_critic_safety_validator_config.py b/backend/app/core/validators/config/llm_critic_safety_validator_config.py
index 832130e..4d31d58 100644
--- a/backend/app/core/validators/config/llm_critic_safety_validator_config.py
+++ b/backend/app/core/validators/config/llm_critic_safety_validator_config.py
@@ -3,6 +3,8 @@
 from guardrails.hub import LLMCritic
 
 from app.core.config import settings
+from app.core.constants import LLM_CRITIC_REPHRASE_MESSAGE
+from app.core.enum import GuardrailOnFail
 from app.core.validators.config.base_validator_config import BaseValidatorConfig
 
 
@@ -12,6 +14,11 @@
     max_score: int
     llm_callable: str
 
+    def resolve_on_fail(self):
+        if self.on_fail == GuardrailOnFail.Rephrase:
+            return lambda value, fail_result: LLM_CRITIC_REPHRASE_MESSAGE
+        return super().resolve_on_fail()
+
     def build(self):
         if not settings.OPENAI_API_KEY:
             raise ValueError(
diff --git a/backend/app/tests/test_llm_validators.py b/backend/app/tests/test_llm_validators.py
index 28196ae..5834843 100644
--- a/backend/app/tests/test_llm_validators.py
+++ b/backend/app/tests/test_llm_validators.py
@@ -103,16 +103,15 @@ def test_llm_critic_build_proceeds_when_openai_key_present():
 def test__normalize_llm_critic_error_maps_failed_metrics():
     raw = "The response failed the following metrics: ['quality']."
     result = _normalize_llm_critic_error(raw)
-    assert result == "The response did not meet the required quality criteria."
+    assert result == "The query did not meet the required quality criteria."
 
 
 def test__normalize_llm_critic_error_maps_missing_invalid_metrics():
     raw = "The response is missing or has invalid evaluations for the following metrics: ['quality']."
     result = _normalize_llm_critic_error(raw)
-    assert "could not evaluate" in result
-    assert "Please retry" in result
+    assert result == "The query did not meet the required quality criteria."
 
 
 def test__normalize_llm_critic_error_passes_through_unknown_messages():
     raw = "Some other validator error."
     assert _normalize_llm_critic_error(raw) == raw