diff --git a/backend/routes/chatbot.py b/backend/routes/chatbot.py
index 7bd96e2..c3bbbca 100644
--- a/backend/routes/chatbot.py
+++ b/backend/routes/chatbot.py
@@ -5,6 +5,7 @@
 import asyncio
 import json
 import logging
+import re
 import uuid
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
@@ -104,6 +105,11 @@
 
 DEFAULT_FAST_MODEL = os.environ.get("ANTHROPIC_FAST_MODEL", "claude-haiku-4-5")
 DEFAULT_COMPLEX_MODEL = os.environ.get("ANTHROPIC_COMPLEX_MODEL", "claude-sonnet-4-6")
+# Reasoning-heavy model for reform / distributional questions where Haiku
+# burns through the iteration budget guessing at the reform-API shape. Opus
+# typically converges in 2–4 iterations on the same prompt, so net cost goes
+# down even though per-turn cost is higher.
+DEFAULT_REASONING_MODEL = os.environ.get("ANTHROPIC_REASONING_MODEL", "claude-opus-4-5")
 TITLE_MODEL = os.environ.get("ANTHROPIC_TITLE_MODEL", DEFAULT_FAST_MODEL)
 # Follow-up suggestion chips run on the same fast model — cheap, latency-tolerant.
 SUGGESTION_MODEL = os.environ.get("ANTHROPIC_SUGGESTION_MODEL", DEFAULT_FAST_MODEL)
@@ -159,7 +165,138 @@ def _estimate_message_tokens(messages: List[dict]) -> int:
     return char_count // 4
 
 
-def _select_chat_model(messages: List[dict]) -> str:
+# Reform / distributional / reasoning-heavy signals. Each pattern targets
+# a concrete kind of question that Haiku tends to get stuck on (reform-API
+# shape, decile aggregations, marginal-rate reasoning). Keywords match
+# case-insensitively at the start of a word ("reform" also hits "reforms",
+# but "gini" does not fire inside "Virginia"); kept hardcoded for now so we
+# can iterate quickly on the list. Closes #83.
+_REFORM_KEYWORDS: List[str] = [
+    # Distributional analysis vocabulary
+    "decile",
+    "quintile",
+    "distributional",
+    "winners",
+    "losers",
+    "poverty",
+    "inequality",
+    "gini",
+    # Reform-shape verbs and nouns
+    "reform",
+    "increase the",
+    "raise the",
+    "cut the",
+    "change the",
+    "replace",
+    "freeze",
+    "uprate",
+    "bump",
+    # Marginal / effective rate reasoning
+    "marginal rate",
+    "effective rate",
+    "marginal tax",
+    "effective tax",
+    # Magnitude expressions that usually accompany reforms
+    "percentage point",  # also matches "percentage points"
+    "1pp",
+]
+
+# One compiled matcher per keyword, anchored on a leading word boundary.
+# Plain substring checks produced false upgrades to the expensive model
+# ("gini" in "imagining", "raise the" in "praise the", "losers" in "closers").
+_REFORM_KEYWORD_PATTERNS = [
+    (kw, re.compile(r"\b" + re.escape(kw))) for kw in _REFORM_KEYWORDS
+]
+
+# Regex for magnitude expressions that aren't naturally captured as substrings:
+#   "by 5%" — bump amounts
+#   "from 20% to 25%" — rate changes
+#   "2pp" / "2 pp" — percentage-point shorthand with any digit
+_REFORM_REGEX = re.compile(
+    r"(?:\bby\s+\d+(?:\.\d+)?\s*%)"
+    r"|(?:\bfrom\s+\d+(?:\.\d+)?\s*%\s*to\s+\d+(?:\.\d+)?\s*%)"
+    r"|(?:\b\d+\s*pp\b)",
+    re.IGNORECASE,
+)
+
+
+def _last_user_text(messages: List[dict]) -> str:
+    """Return the text of the most recent user message, flattening structured
+    content blocks (e.g. image + text) into a single string.
+
+    Non-text blocks are skipped; returns "" when no user message exists.
+    """
+    for msg in reversed(messages):
+        if msg.get("role") != "user":
+            continue
+        content = msg.get("content", "")
+        if isinstance(content, str):
+            return content
+        if isinstance(content, list):
+            # `or ""` keeps a null text field from becoming the string "None".
+            return " ".join(
+                str(block.get("text") or "")
+                for block in content
+                if isinstance(block, dict) and block.get("type") == "text"
+            )
+        return "" if content is None else str(content)
+    return ""
+
+
+def _detect_reform_signal(text: str) -> Optional[str]:
+    """Return the matched signal if `text` looks like a reform /
+    distributional / reasoning-heavy question, else None.
+
+    Keywords are tried first, in list order, and the keyword itself is
+    returned; otherwise the substring matched by `_REFORM_REGEX` is returned.
+    Cheap pure-Python checks — no LLM call.
+    """
+    if not text:
+        return None
+    lowered = text.lower()
+    for kw, pattern in _REFORM_KEYWORD_PATTERNS:
+        if pattern.search(lowered):
+            return kw
+    m = _REFORM_REGEX.search(lowered)
+    if m:
+        return m.group(0)
+    return None
+
+
+def _select_chat_model(
+    messages: List[dict],
+    *,
+    plan_mode: bool = False,
+    charts_mode: bool = False,
+) -> str:
+    """Route to a model for this turn.
+
+    Decision order:
+    1. Plan mode → fast model. Plan turns just ask 1–3 clarifying questions,
+       no tool use, low cognitive load. Never upgrade.
+    2. Reform / distributional signal in the last user message → reasoning
+       model (Opus). Haiku burns iterations guessing at the reform API; Opus
+       converges in 2–4 iterations on the same prompt.
+    3. Charts mode → reasoning model. Charts usually imply distributional
+       analysis (decile / percentile / trend), which is the same failure
+       mode as (2).
+    4. Conversation too large for the fast model's effective context →
+       complex model (Sonnet) as before.
+    5. Otherwise → fast model.
+    """
+    if plan_mode:
+        return DEFAULT_FAST_MODEL
+
+    last_user = _last_user_text(messages)
+    signal = _detect_reform_signal(last_user)
+    if signal:
+        logger.info(f"[MODEL] Routed to Opus (reform signal: {signal!r})")
+        return DEFAULT_REASONING_MODEL
+
+    if charts_mode:
+        logger.info("[MODEL] Routed to Opus (charts_mode=True)")
+        return DEFAULT_REASONING_MODEL
+
     estimated_input_tokens = (
         _estimate_message_tokens(messages)
         + len(SYSTEM_PROMPT) // 4
@@ -426,10 +563,14 @@ async def generate_stream():
         recent_tool_calls: List[str] = []
 
         client = _get_anthropic_client()
-        model = _select_chat_model(conversation)
-        tools = _tool_defs_for_anthropic()
         plan_mode = chat_request.plan_mode
         charts_mode = chat_request.charts_mode
+        model = _select_chat_model(
+            conversation,
+            plan_mode=plan_mode,
+            charts_mode=charts_mode,
+        )
+        tools = _tool_defs_for_anthropic()
         system_blocks = _build_system_blocks(plan_mode=plan_mode, charts_mode=charts_mode)
 
         logger.info(
diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py
index 8ede926..9dd3050 100644
--- a/backend/tests/test_api.py
+++ b/backend/tests/test_api.py
@@ -290,6 +290,54 @@ def test_request_accepts_plan_mode_field(self):
         assert req2.plan_mode is False
 
 
+# ---------------------------------------------------------------------------
+# Model routing — reform / distributional questions upgrade to Opus (#83)
+# ---------------------------------------------------------------------------
+
+class TestSelectChatModel:
+    """Pure-Python checks on the model routing heuristic.
+
+    These don't hit Anthropic — they exercise `_select_chat_model` directly
+    with synthetic message lists and verify the routing decision.
+    """
+
+    def test_decile_reform_routes_to_opus(self):
+        from routes.chatbot import _select_chat_model, DEFAULT_REASONING_MODEL
+        msgs = [{"role": "user", "content": "Show me the decile impact of a reform raising the personal allowance by 5%"}]
+        assert _select_chat_model(msgs) == DEFAULT_REASONING_MODEL
+
+    def test_plain_question_routes_to_fast_model(self):
+        from routes.chatbot import _select_chat_model, DEFAULT_FAST_MODEL
+        msgs = [{"role": "user", "content": "What is the personal allowance for 2025?"}]
+        assert _select_chat_model(msgs) == DEFAULT_FAST_MODEL
+
+    def test_plan_mode_overrides_reform_signal(self):
+        from routes.chatbot import _select_chat_model, DEFAULT_FAST_MODEL
+        msgs = [{"role": "user", "content": "decile breakdown of a reform"}]
+        assert _select_chat_model(msgs, plan_mode=True) == DEFAULT_FAST_MODEL
+
+    def test_charts_mode_upgrades_without_reform_signal(self):
+        from routes.chatbot import _select_chat_model, DEFAULT_REASONING_MODEL
+        msgs = [{"role": "user", "content": "Plot the income tax schedule"}]
+        assert _select_chat_model(msgs, charts_mode=True) == DEFAULT_REASONING_MODEL
+
+    def test_rate_change_phrase_routes_to_opus(self):
+        from routes.chatbot import _select_chat_model, DEFAULT_REASONING_MODEL
+        msgs = [{"role": "user", "content": "Move the basic rate from 20% to 21%"}]
+        assert _select_chat_model(msgs) == DEFAULT_REASONING_MODEL
+
+    def test_text_blocks_in_structured_content_are_scanned(self):
+        from routes.chatbot import _select_chat_model, DEFAULT_REASONING_MODEL
+        msgs = [{"role": "user", "content": [{"type": "text", "text": "Who are the winners and losers?"}]}]
+        assert _select_chat_model(msgs) == DEFAULT_REASONING_MODEL
+
+    def test_keyword_embedded_in_longer_word_is_ignored(self):
+        from routes.chatbot import _select_chat_model, DEFAULT_FAST_MODEL
+        # "imagining" contains "gini", which must not count as a reform signal.
+        msgs = [{"role": "user", "content": "I'm imagining a household with two children, what child benefit do they get?"}]
+        assert _select_chat_model(msgs) == DEFAULT_FAST_MODEL
+
+
 # ---------------------------------------------------------------------------
 # Rate limiting
 # ---------------------------------------------------------------------------