Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 128 additions & 3 deletions backend/routes/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import json
import logging
import re
import uuid
from typing import Any, Dict, List

Expand Down Expand Up @@ -104,6 +105,11 @@

# Model ids are read from the environment once, when this module is imported;
# each falls back to the hardcoded id on the right when its variable is unset.
DEFAULT_FAST_MODEL = os.environ.get("ANTHROPIC_FAST_MODEL", "claude-haiku-4-5")
DEFAULT_COMPLEX_MODEL = os.environ.get("ANTHROPIC_COMPLEX_MODEL", "claude-sonnet-4-6")
# Reasoning-heavy model for reform / distributional questions where Haiku
# burns through the iteration budget guessing at the reform-API shape. Opus
# typically converges in 2–4 iterations on the same prompt, so net cost goes
# down even though per-turn cost is higher.
DEFAULT_REASONING_MODEL = os.environ.get("ANTHROPIC_REASONING_MODEL", "claude-opus-4-5")
# Falls back to whatever DEFAULT_FAST_MODEL resolved to above, so overriding
# ANTHROPIC_FAST_MODEL also changes this default.
TITLE_MODEL = os.environ.get("ANTHROPIC_TITLE_MODEL", DEFAULT_FAST_MODEL)
# Follow-up suggestion chips run on the same fast model — cheap, latency-tolerant.
SUGGESTION_MODEL = os.environ.get("ANTHROPIC_SUGGESTION_MODEL", DEFAULT_FAST_MODEL)
Expand Down Expand Up @@ -159,7 +165,122 @@ def _estimate_message_tokens(messages: List[dict]) -> int:
return char_count // 4


def _select_chat_model(messages: List[dict]) -> str:
# Reform / distributional / reasoning-heavy signals. Each pattern targets
# a concrete kind of question that Haiku tends to get stuck on (reform-API
# shape, decile aggregations, marginal-rate reasoning). Patterns are
# case-insensitive substrings or word-boundary regexes; kept hardcoded for
# now so we can iterate quickly on the list. Closes #83.
#
# Entries are compared with a plain `in` test against the lowercased user
# text, so every entry here must itself be lowercase. The list is scanned in
# order and the first hit is the one reported, so order only affects which
# signal shows up in the log line, not whether routing is triggered.
#
# NOTE(review): because these are raw substrings, an entry also matches inside
# longer words (e.g. "gini" inside "virginia", "bump" inside "bumper"). Keep
# entries specific enough that such accidental hits stay rare.
_REFORM_KEYWORDS: List[str] = [
    # Distributional analysis vocabulary
    "decile",
    "quintile",
    "distributional",
    "winners",
    "losers",
    "poverty",
    "inequality",
    "gini",
    # Reform-shape verbs and nouns
    "reform",
    "increase the",
    "raise the",
    "cut the",
    "change the",
    "replace",
    "freeze",
    "uprate",
    "bump",
    # Marginal / effective rate reasoning
    "marginal rate",
    "effective rate",
    "marginal tax",
    "effective tax",
    # Magnitude expressions that usually accompany reforms
    "percentage point",  # also matches "percentage points"
    "1pp",
]

# Regex for magnitude expressions that aren't naturally captured as substrings:
#   "by 5%"            — bump amounts (integer or decimal; optional space before %)
#   "from 20% to 25%"  — rate changes (integer or decimal on both sides)
#   "2pp" / "2 pp"     — percentage-point shorthand with any digit
#
# The three alternatives are OR-ed together and compiled once at import time.
# The "pp" alternative accepts only a run of digits, so for a decimal such as
# "0.5pp" the reported match is the trailing "5pp" — still a hit, just a
# shorter matched string.
_REFORM_REGEX = re.compile(
    r"(?:\bby\s+\d+(?:\.\d+)?\s*%)"
    r"|(?:\bfrom\s+\d+(?:\.\d+)?\s*%\s*to\s+\d+(?:\.\d+)?\s*%)"
    r"|(?:\b\d+\s*pp\b)",
    re.IGNORECASE,
)


def _last_user_text(messages: List[dict]) -> str:
"""Return the text of the most recent user message, flattening structured
content blocks (e.g. image + text) into a single string."""
for msg in reversed(messages):
if msg.get("role") != "user":
continue
content = msg.get("content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: List[str] = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
parts.append(str(block.get("text", "")))
return " ".join(parts)
return str(content)
return ""


def _detect_reform_signal(text: str) -> str | None:
    """Identify whether `text` reads like a reform / distributional /
    reasoning-heavy question.

    The text is lowercased, then checked first against the keyword list (in
    list order, plain substring containment) and then against the magnitude
    regex. Returns the keyword or the regex-matched substring that fired, or
    None when nothing matches or `text` is empty. Pure string work — no LLM
    call is made.
    """
    if not text:
        return None

    haystack = text.lower()

    # Keywords take precedence over the regex; first hit in list order wins.
    keyword_hit = next((kw for kw in _REFORM_KEYWORDS if kw in haystack), None)
    if keyword_hit is not None:
        return keyword_hit

    regex_hit = _REFORM_REGEX.search(haystack)
    return regex_hit.group(0) if regex_hit else None


def _select_chat_model(
messages: List[dict],
*,
plan_mode: bool = False,
charts_mode: bool = False,
) -> str:
"""Route to a model for this turn.

Decision order:
1. Plan mode → fast model. Plan turns just ask 1–3 clarifying questions,
no tool use, low cognitive load. Never upgrade.
2. Reform / distributional signal in the last user message → reasoning
model (Opus). Haiku burns iterations guessing at the reform API; Opus
converges in 2–4 iterations on the same prompt.
3. Charts mode → reasoning model. Charts usually imply distributional
analysis (decile / percentile / trend), which is the same failure
mode as (2).
4. Conversation too large for the fast model's effective context →
complex model (Sonnet) as before.
5. Otherwise → fast model.
"""
if plan_mode:
return DEFAULT_FAST_MODEL

last_user = _last_user_text(messages)
signal = _detect_reform_signal(last_user)
if signal:
logger.info(f"[MODEL] Routed to Opus (reform signal: {signal!r})")
return DEFAULT_REASONING_MODEL

if charts_mode:
logger.info("[MODEL] Routed to Opus (charts_mode=True)")
return DEFAULT_REASONING_MODEL

estimated_input_tokens = (
_estimate_message_tokens(messages)
+ len(SYSTEM_PROMPT) // 4
Expand Down Expand Up @@ -426,10 +547,14 @@ async def generate_stream():
recent_tool_calls: List[str] = []

client = _get_anthropic_client()
model = _select_chat_model(conversation)
tools = _tool_defs_for_anthropic()
plan_mode = chat_request.plan_mode
charts_mode = chat_request.charts_mode
model = _select_chat_model(
conversation,
plan_mode=plan_mode,
charts_mode=charts_mode,
)
tools = _tool_defs_for_anthropic()
system_blocks = _build_system_blocks(plan_mode=plan_mode, charts_mode=charts_mode)

logger.info(
Expand Down
32 changes: 32 additions & 0 deletions backend/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,38 @@ def test_request_accepts_plan_mode_field(self):
assert req2.plan_mode is False


# ---------------------------------------------------------------------------
# Model routing — reform / distributional questions upgrade to Opus (#83)
# ---------------------------------------------------------------------------

class TestSelectChatModel:
    """Routing-heuristic checks for `_select_chat_model`.

    Every case builds a one-message synthetic conversation and compares the
    returned model id with the expected module-level default. No Anthropic
    call is made.
    """

    @staticmethod
    def _single_user_turn(text):
        """Wrap `text` as a conversation holding one user message."""
        return [{"role": "user", "content": text}]

    def test_decile_reform_routes_to_opus(self):
        from routes.chatbot import _select_chat_model, DEFAULT_REASONING_MODEL

        conversation = self._single_user_turn(
            "Show me the decile impact of a reform raising the personal allowance by 5%"
        )
        chosen = _select_chat_model(conversation)
        assert chosen == DEFAULT_REASONING_MODEL

    def test_plain_question_routes_to_fast_model(self):
        from routes.chatbot import _select_chat_model, DEFAULT_FAST_MODEL

        conversation = self._single_user_turn("What is the personal allowance for 2025?")
        chosen = _select_chat_model(conversation)
        assert chosen == DEFAULT_FAST_MODEL

    def test_plan_mode_overrides_reform_signal(self):
        from routes.chatbot import _select_chat_model, DEFAULT_FAST_MODEL

        # Plan mode must win even though the text carries reform keywords.
        conversation = self._single_user_turn("decile breakdown of a reform")
        chosen = _select_chat_model(conversation, plan_mode=True)
        assert chosen == DEFAULT_FAST_MODEL

    def test_charts_mode_upgrades_without_reform_signal(self):
        from routes.chatbot import _select_chat_model, DEFAULT_REASONING_MODEL

        # No reform keyword in the text; the upgrade comes from charts_mode.
        conversation = self._single_user_turn("Plot the income tax schedule")
        chosen = _select_chat_model(conversation, charts_mode=True)
        assert chosen == DEFAULT_REASONING_MODEL


# ---------------------------------------------------------------------------
# Rate limiting
# ---------------------------------------------------------------------------
Expand Down