diff --git a/.github/workflows/pr-beta-deploy.yml b/.github/workflows/pr-beta-deploy.yml index aee81d3..313b9de 100644 --- a/.github/workflows/pr-beta-deploy.yml +++ b/.github/workflows/pr-beta-deploy.yml @@ -89,6 +89,7 @@ jobs: SUPABASE_SERVICE_ROLE_KEY="$SUPABASE_SERVICE_ROLE_KEY" \ STRIPE_SECRET_KEY="$STRIPE_SECRET_KEY" \ STRIPE_WEBHOOK_SECRET="$STRIPE_WEBHOOK_SECRET" \ + POLICYENGINE_CHAT_TOPIC_GATE_ENABLED=true \ HOSTNAME_REGEX="^https://policyengine-uk-chat-git-[a-z0-9-]+-policy-engine\.vercel\.app$" \ PUBLIC_BASE_URL="https://policyengine-uk-chat.vercel.app" \ --force @@ -135,6 +136,7 @@ jobs: SUPABASE_SERVICE_ROLE_KEY="$SUPABASE_SERVICE_ROLE_KEY" \ STRIPE_SECRET_KEY="$STRIPE_SECRET_KEY" \ STRIPE_WEBHOOK_SECRET="$STRIPE_WEBHOOK_SECRET" \ + POLICYENGINE_CHAT_TOPIC_GATE_ENABLED=true \ HOSTNAME_REGEX="^https://policyengine-uk-chat-git-[a-z0-9-]+-policy-engine\.vercel\.app$" \ PUBLIC_BASE_URL="$FRONTEND_URL" \ --force diff --git a/backend/model_backends.py b/backend/model_backends.py index f21d52f..b931aac 100644 --- a/backend/model_backends.py +++ b/backend/model_backends.py @@ -338,16 +338,26 @@ def capabilities() -> Dict[str, Any]: } +# Metadata for /chat/backends. Computed lazily on first call, then cached: +# the values (id/label/version) don't change within a deploy. Avoids paying +# for importlib.metadata.version() — and the heavy package import it may +# trigger — on every request. 
+_BACKENDS_METADATA_CACHE: Dict[str, Dict[str, str]] | None = None + + def available_backends() -> Dict[str, Dict[str, str]]: - return { - backend_id: { - "id": backend.id, - "display_name": backend.display_name, - "package_label": backend.package_label, - "version": backend.package_version(), + global _BACKENDS_METADATA_CACHE + if _BACKENDS_METADATA_CACHE is None: + _BACKENDS_METADATA_CACHE = { + backend_id: { + "id": backend.id, + "display_name": backend.display_name, + "package_label": backend.package_label, + "version": backend.package_version(), + } + for backend_id, backend in _BACKENDS.items() } - for backend_id, backend in _BACKENDS.items() - } + return {backend_id: dict(meta) for backend_id, meta in _BACKENDS_METADATA_CACHE.items()}  # fresh copy per call, as before the cache existed, so callers cannot mutate the cached metadata def get_backend(backend_id: str | None = None) -> ModelBackend: diff --git a/backend/routes/chatbot.py b/backend/routes/chatbot.py index d091d6f..eee8473 100644 --- a/backend/routes/chatbot.py +++ b/backend/routes/chatbot.py @@ -96,6 +96,11 @@ def _build_system_prompt(backend_id: str) -> str: TITLE_MODEL = os.environ.get("ANTHROPIC_TITLE_MODEL", DEFAULT_FAST_MODEL) FAST_MODEL_MAX_INPUT_TOKENS = int(os.environ.get("ANTHROPIC_FAST_MODEL_MAX_INPUT_TOKENS", "120000")) +# Topic gate — short-circuits requests that are clearly off-topic before they +# hit the main loop. Opt-in via env so rollout can be staged. +TOPIC_GATE_ENABLED = os.environ.get("POLICYENGINE_CHAT_TOPIC_GATE_ENABLED", "false").lower() == "true" +TOPIC_GATE_MODEL = os.environ.get("POLICYENGINE_CHAT_TOPIC_GATE_MODEL", DEFAULT_FAST_MODEL) + _REFERENCE_PATH = Path(__file__).resolve().parent.parent / "reference.md" try: REFERENCE_DOC = _REFERENCE_PATH.read_text() @@ -261,6 +266,78 @@ def list_backends(): } +# --------------------------------------------------------------------------- +# Topic gate +# --------------------------------------------------------------------------- + +# Calibration for the boundary cases: +# - "Capital of France?" → no (reject) +# - "What did the chancellor say yesterday?" 
→ no (reject — news, not policy) +# - "How will the PA reform affect inflation?" → yes (let through; the main +# loop's scope-refusal then explains microsim-vs-macro — see eval A4) +# - "What's the EITC?" / "How does UC taper?" → yes (factual policy) +# +# Failure mode preference: false negatives (rejecting on-topic) are worse than +# false positives (accepting off-topic). The latter wastes a few cents; the +# former breaks the product. So the prompt biases toward letting things +# through, and any classifier error short-circuits to "yes". +_TOPIC_GATE_SYSTEM = """You are a strict classifier deciding whether to forward a user's question to a UK tax-and-benefit policy assistant. + +Reply with exactly one token: "yes" or "no". + +Reply "yes" when the question is, or could plausibly be, about: +- UK or US tax, benefits, social-security, or public-finance policy +- Specific programmes (Universal Credit, EITC, CTC, SNAP, NHS, state pension, etc.) +- Household-level financial situations the assistant could simulate +- Reforms, hypothetical policy changes, distributional or budgetary effects +- Whether something is in scope for a microsimulation model (the assistant will explain limitations) +- Follow-ups, clarifications, or chit-chat that names a policy topic + +Reply "no" only when the question is unambiguously NOT about policy — e.g.: +- General knowledge (capitals, history, science, sports, weather) +- News or current events not tied to a specific policy +- Personal advice, emotional support, creative writing +- Coding help unrelated to policy modelling + +When in doubt, reply "yes". +""" + + +def _classify_on_topic(last_user_message: str) -> bool: + """Single Haiku classification call. 
Fail-open on any error.""" + if not last_user_message or not last_user_message.strip(): + return True + try: + client = _get_sync_anthropic_client() + response = client.messages.create( + model=TOPIC_GATE_MODEL, + max_tokens=4, + system=_TOPIC_GATE_SYSTEM, + messages=[{"role": "user", "content": last_user_message[:2000]}], + ) + text = response.content[0].text.strip().lower() if response.content else "" + return not (text.startswith("no") and not text[2:3].isalpha())  # reject only a standalone "no" ("no", "no.", "no, ..."); replies like "not sure" or "none" fail open + except Exception as e: + logger.warning(f"[CHAT] Topic gate classification failed; failing open: {e}") + return True + + +_OFF_TOPIC_REFUSAL = ( + "I'm built for UK tax and benefit policy questions — things like reform " + "impacts, eligibility, programme parameters, or distributional effects. " + "I can't help with this one. If you'd like, ask me about a UK policy or " + "household situation and I'll run the numbers." +) + + +def _off_topic_refusal_stream(session_id: str, backend_id: str, model: str): + """Emit an SSE stream the frontend renders identically to a normal answer.""" + async def gen(): + yield f"data: {json.dumps({'type': 'chunk', 'content': _OFF_TOPIC_REFUSAL})}\n\n" + yield f"data: {json.dumps({'type': 'done', 'content': _OFF_TOPIC_REFUSAL, 'session_id': session_id, 'model': model, 'model_backend': backend_id, 'usage': {'input_tokens': 0, 'output_tokens': 0, 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0}, 'cost_gbp': None, 'balance': None, 'refused_by_topic_gate': True})}\n\n" + return gen + + # --------------------------------------------------------------------------- # Chat endpoint — SSE streaming # --------------------------------------------------------------------------- @@ -287,6 +364,23 @@ async def chat_message(request: ChatRequest, http_request: Request): except ValueError as e: return JSONResponse(status_code=400, content={"error": str(e)}) + # Topic gate — short-circuit clearly off-topic messages before we load + # the system prompt, reference doc, or run any tools. 
Only checks the + # most recent user message; mid-conversation drift is left to the main + # loop's scope guidance. + if TOPIC_GATE_ENABLED: + import asyncio as _asyncio  # local alias so this import cannot shadow a module-level asyncio + last_user = next((m.content for m in reversed(request.messages) if m.role == "user"), "") + # _classify_on_topic uses the synchronous Anthropic client; run it in a + # worker thread so the event loop keeps serving other requests meanwhile. + if not await _asyncio.to_thread(_classify_on_topic, last_user): + logger.info(f"[CHAT] Topic gate rejected message in session {session_id}") + return StreamingResponse( + _off_topic_refusal_stream(session_id, backend.id, TOPIC_GATE_MODEL)(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + messages = [{"role": msg.role, "content": msg.content} for msg in request.messages] # Deduplicate consecutive same-role messages diff --git a/backend/tests/test_topic_gate.py b/backend/tests/test_topic_gate.py new file mode 100644 index 0000000..31ba88c --- /dev/null +++ b/backend/tests/test_topic_gate.py @@ -0,0 +1,85 @@ +"""Tests for the optional Haiku-based topic gate in /chat/message.""" + +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from routes import chatbot + + +def _stub_client_returning(text: str): + """Return a stand-in Anthropic client whose .messages.create() yields `text`.""" + response = SimpleNamespace(content=[SimpleNamespace(text=text)]) + client = SimpleNamespace(messages=SimpleNamespace(create=lambda **_: response)) + return client + + +class TestClassifyOnTopic: + """Boundary-case calibration for the Haiku classifier prompt. + + These tests don't actually call Haiku — they patch the Anthropic client + so we can lock in the parser's behaviour against representative responses. + The real classifier prompt is exercised by manual eval, not unit tests. 
+ """ + + @pytest.mark.parametrize( + "model_reply,expected", + [ + ("yes", True), + ("Yes", True), + ("yes.", True), + ("YES — clearly policy", True), + ("no", False), + ("No.", False), + ("no, off-topic", False), + ("", True), # malformed → fail open + ("maybe", True), # not starting with "no" → fail open + ], + ) + def test_parses_model_reply(self, model_reply, expected): + with patch.object(chatbot, "_get_sync_anthropic_client", lambda: _stub_client_returning(model_reply)): + assert chatbot._classify_on_topic("any question") is expected + + def test_empty_input_passes_through(self): + # No need to call the model at all when the message is empty. + with patch.object(chatbot, "_get_sync_anthropic_client", side_effect=AssertionError("should not be called")): + assert chatbot._classify_on_topic("") is True + assert chatbot._classify_on_topic(" ") is True + + def test_anthropic_error_fails_open(self): + def boom(): + raise RuntimeError("anthropic down") + with patch.object(chatbot, "_get_sync_anthropic_client", boom): + assert chatbot._classify_on_topic("How does Universal Credit work?") is True + + +class TestChatMessageGate: + """End-to-end gate behaviour via TestClient. + + Gate is off by default in tests (env var unset). When turned on with the + classifier stubbed to reject, /chat/message returns an SSE stream containing + the canned refusal and never invokes the heavy chat loop. + """ + + def test_gate_off_by_default(self): + # No assertions about behaviour here — just that the module imports + # and the default config keeps the gate disabled. 
+ assert chatbot.TOPIC_GATE_ENABLED is False + + def test_gate_on_rejects_off_topic(self, monkeypatch): + from fastapi.testclient import TestClient + from main import app + + monkeypatch.setattr(chatbot, "TOPIC_GATE_ENABLED", True) + monkeypatch.setattr(chatbot, "_classify_on_topic", lambda _msg: False) + + client = TestClient(app) + resp = client.post( + "/chat/message", + json={"messages": [{"role": "user", "content": "What's the capital of France?"}]}, + ) + assert resp.status_code == 200 + body = resp.text + assert "UK tax and benefit" in body + assert "refused_by_topic_gate" in body diff --git a/frontend/src/app/ChatPage.tsx b/frontend/src/app/ChatPage.tsx index aa4a9ba..df5bd93 100644 --- a/frontend/src/app/ChatPage.tsx +++ b/frontend/src/app/ChatPage.tsx @@ -193,6 +193,7 @@ export default function ChatPage() { const abortRef = useRef(null); const [modelBackends, setModelBackends] = useState([]); + const [backendsLoading, setBackendsLoading] = useState(true); const [selectedBackendId, setSelectedBackendId] = useState("uk_compiled"); const [balance, setBalance] = useState(null); const [topUpLoading, setTopUpLoading] = useState(false); @@ -239,7 +240,8 @@ export default function ChatPage() { setModelBackends(options); setSelectedBackendId(nextBackend); }) - .catch(() => {}); + .catch(() => {}) + .finally(() => setBackendsLoading(false)); // Refresh balance after Stripe redirect if (typeof window !== "undefined" && new URLSearchParams(window.location.search).get("topup") === "success") { window.history.replaceState({}, "", window.location.pathname); @@ -1067,7 +1069,13 @@ export default function ChatPage() { {!hasMessages && Press Enter to send · Shift+Enter for new line}
- {modelBackends.length > 1 && ( + {backendsLoading && ( +
+ + Loading engines… +
+ )} + {!backendsLoading && modelBackends.length > 1 && (
Engine
diff --git a/modal_app.py b/modal_app.py index 884afef..d8004e7 100644 --- a/modal_app.py +++ b/modal_app.py @@ -15,11 +15,27 @@ def _preload_engine(): - """Bake the compiled engine into the image snapshot for fast cold starts.""" + """Bake the engines into the image snapshot for fast cold starts. + + The compiled (Rust) backend is the default and gets the full warm-up. + The Python backends only need their packages importable — that's enough + to make `/chat/backends` return without paying for the heavy + PolicyEngine Core/OpenFisca import on the first request. + """ from policyengine_uk_compiled import Simulation sim = Simulation(year=2024) sim.get_baseline_params() - print("Engine pre-loaded.") + print("Compiled engine pre-loaded.") + + # Best-effort imports of the Python backends. Failures are non-fatal — + # the chat works without them; this is purely to shave cold-start latency + # off /chat/backends. + for pkg in ("policyengine_uk", "policyengine_us"): + try: + __import__(pkg) + print(f"{pkg} pre-imported.") + except Exception as exc:  # any import-time failure is non-fatal here, not only a missing package + print(f"{pkg} pre-import failed ({exc!r}); skipping.") image = (