From 9fd7354182ed506d02f806ed1ac9e2dfe81ebf46 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 13 May 2026 00:20:16 -0700 Subject: [PATCH 1/7] Add model backend selector and Claude Agent SDK runner Backends: uk_compiled (Rust) and uk_python (Python), selected per request or via POLICYENGINE_CHAT_BACKEND. Each controls its own prompt context, tool description, and execution globals. ChatPage exposes a toggle persisted in localStorage. /chat/backends lists what's available. Rebased onto main: keeps main's structural plan-mode enforcement and cached-reference cache breakpoints; backend prompt context is injected into the cached first block via a template placeholder. Claude Agent SDK runner is opt-in (POLICYENGINE_CHAT_AGENT_RUNNER=claude_sdk), mirrors the existing SSE event contract, exposes run_python via an in-process MCP tool. Also: modal_app installs from backend/requirements.txt, pins policyengine_uk==2.88.0, adds GA tracking. Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/agent_tools.py | 166 ++++++++----- backend/claude_agent_sdk_runner.py | 281 ++++++++++++++++++++++ backend/model_backends.py | 370 +++++++++++++++++++++++++++++ backend/requirements.txt | 4 +- backend/routes/chatbot.py | 102 +++++--- backend/tests/test_agent_tools.py | 29 +++ backend/tests/test_api.py | 30 ++- frontend/src/app/ChatPage.tsx | 155 +++++++++--- frontend/src/app/layout.tsx | 16 ++ modal_app.py | 16 +- 10 files changed, 1023 insertions(+), 146 deletions(-) create mode 100644 backend/claude_agent_sdk_runner.py create mode 100644 backend/model_backends.py diff --git a/backend/agent_tools.py b/backend/agent_tools.py index aaefed2..fa902ff 100644 --- a/backend/agent_tools.py +++ b/backend/agent_tools.py @@ -10,6 +10,8 @@ import sys from typing import Any, Dict, List, Optional +from model_backends import get_backend, make_backend_importer + logger = logging.getLogger(__name__) @@ -538,6 +540,59 @@ def analyse_microdata( return {"error": str(e)} +def compute( + operation: str, + data: List[float], + other: Optional[List[float]] = None, +) -> Dict[str, Any]: + """Small numeric helper retained for older tests/tool dispatch paths.""" + if not data: + return {"error": "data must be non-empty"} + + try: + if operation == "diff": + return { + "result": [ + data[i + 1] - data[i] for i in range(len(data) - 1) + ] + } + if operation == "pct_change": + return { + "result": [ + 0 if data[i] == 0 else (data[i + 1] - data[i]) / data[i] * 100 + for i in range(len(data) - 1) + ] + } + if operation == "mean": + return {"result": sum(data) / len(data)} + if operation == "sum": + return {"result": sum(data)} + if operation in {"subtract", "divide", "marginal_rate"}: + if other is None: + return {"error": f"{operation} requires other"} + if len(data) != len(other): + return {"error": "data and other must have the same length"} + if operation == "subtract": + return {"result": [a - b for a, b in zip(data, other)]} + if operation == "divide": + return { + "result": [ + 0 if b == 0 else a / b for a, b in zip(data, other) + ] + } + return { + "result": [ + 0 + if other[i + 1] == other[i] + else 100 * (data[i + 1] - data[i]) / (other[i + 1] - other[i]) + for i in range(len(data) - 1) + ] + } + return {"error": f"Unknown operation: {operation}"} + except Exception as e: + return {"error": str(e)} + + def generate_chart( chart_type: str, title: str, data: List[Dict[str, Any]], x_field: str, y_fields: List[str], x_label: Optional[str] = None, y_label: Optional[str] = None, @@ -578,29 +633,10 @@ def generate_chart( return {"error": str(e)} -def run_python(code: str) -> Dict[str, Any]: - """Execute Python code with the PolicyEngine UK compiled interface preloaded. - - The code should assign its final result to a variable called `result`. - The environment includes the official Python wrapper so runs are easy to - reproduce outside the chat app. - """ - import math +def _safe_builtins_for_backend(backend_id: str) -> Dict[str, Any]: import builtins as _builtins - _ensure_compiled_package_importable() - import pandas as pd - import policyengine_uk_compiled as pe - - from policyengine_uk_compiled import ( - Simulation, - StructuralReform, - Parameters, - aggregate_microdata, - combine_microdata, - capabilities, - ensure_dataset, - ) + backend = get_backend(backend_id) safe_names = ( "range", "len", "int", "float", "str", "bool", "list", "dict", "tuple", "set", "zip", "enumerate", "map", "filter", "sorted", @@ -609,37 +645,33 @@ def run_python(code: str) -> Dict[str, Any]: "print", "any", "all", "pow", "divmod", "complex", "type", "dir", "hasattr", "getattr", ) - safe_builtins = {k: getattr(_builtins, k) for k in safe_names if hasattr(_builtins, k)} + safe_builtins = { + k: getattr(_builtins, k) for k in safe_names if hasattr(_builtins, k) + } + safe_builtins["__import__"] = make_backend_importer(backend) + return safe_builtins - try: - import numpy as np - except ImportError: - np = None + +def run_python(code: str, backend_id: str = "uk_compiled") -> Dict[str, Any]: + """Execute Python code with the selected model backend preloaded. + + The code should assign its final result to a variable called `result`. + The backend adapter controls the model-specific globals made available. + """ + backend = get_backend(backend_id) + safe_builtins = _safe_builtins_for_backend(backend.id) output_lines: List[str] = [] def safe_print(*args, **kwargs): output_lines.append(" ".join(str(a) for a in args)) safe_builtins["print"] = safe_print - safe_builtins["__import__"] = _safe_import - allowed_globals: Dict[str, Any] = { - "__builtins__": safe_builtins, - "math": math, - "json": json, - "pd": pd, - "pe": pe, - "Simulation": Simulation, - "StructuralReform": StructuralReform, - "Parameters": Parameters, - "aggregate_microdata": aggregate_microdata, - "combine_microdata": combine_microdata, - "capabilities": capabilities, - "ensure_dataset": ensure_dataset, - } - if np is not None: - allowed_globals["np"] = np - allowed_globals["numpy"] = np + try: + allowed_globals = backend.execution_globals() + except Exception as e: + return {"error": f"Backend import failed for '{backend.id}': {type(e).__name__}: {e}"} + allowed_globals["__builtins__"] = safe_builtins try: exec(code, allowed_globals) @@ -656,6 +688,7 @@ def safe_print(*args, **kwargs): if not response: response["result"] = None response["note"] = "No 'result' variable was set and nothing was printed." + response["backend"] = backend.id return response @@ -685,10 +718,16 @@ def _run_generator(code: str) -> Dict[str, Any]: return result -def execute_tool(tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]: - logger.info(f"[TOOLS] Executing {tool_name}") +def execute_tool( + tool_name: str, + tool_input: Dict[str, Any], + backend_id: str = "uk_compiled", +) -> Dict[str, Any]: + logger.info(f"[TOOLS] Executing {tool_name} with backend={backend_id}") tools = { "run_python": run_python, + "compute": compute, + "generate_chart": generate_chart, } if tool_name not in tools: return {"error": f"Unknown tool: {tool_name}"} @@ -698,6 +737,8 @@ def execute_tool(tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]: logger.info(f"[TOOLS] Running generator for {tool_name}") tool_input = _run_generator(tool_input["generator"]) logger.info(f"[TOOLS] Generator produced keys: {list(tool_input.keys())}") + if tool_name == "run_python" and "backend_id" not in tool_input: + tool_input = {**tool_input, "backend_id": backend_id} result = tools[tool_name](**tool_input) logger.info(f"[TOOLS] {tool_name} completed") return result @@ -706,16 +747,27 @@ def execute_tool(tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]: return {"error": str(e)} -TOOL_DEFINITIONS = [ - { - "name": "run_python", - "description": "Execute reproducible Python code using the official PolicyEngine UK compiled interface. The environment preloads `policyengine_uk_compiled` as `pe`, plus `Simulation`, `Parameters`, `StructuralReform`, `aggregate_microdata`, `combine_microdata`, `capabilities`, `ensure_dataset`, `pd`, `np`, `json`, and `math`. Assign the final answer to `result` and use `print()` for intermediate output.", - "input_schema": { - "type": "object", - "properties": { - "code": {"type": "string", "description": "Python code to execute. Must assign the final answer to `result`. Use the preloaded PolicyEngine interface directly, for example: `sim = Simulation(year=2025)` or `policy = Parameters.model_validate({...})`."}, +def get_tool_definitions(backend_id: str = "uk_compiled") -> List[Dict[str, Any]]: + backend = get_backend(backend_id) + return [ + { + "name": "run_python", + "description": backend.tool_description(), + "input_schema": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": ( + "Python code to execute. Must assign the final answer " + "to `result`. Use the preloaded model interface directly." + ), + }, + }, + "required": ["code"], }, - "required": ["code"], }, - }, -] + ] + + +TOOL_DEFINITIONS = get_tool_definitions("uk_compiled") diff --git a/backend/claude_agent_sdk_runner.py b/backend/claude_agent_sdk_runner.py new file mode 100644 index 0000000..804e438 --- /dev/null +++ b/backend/claude_agent_sdk_runner.py @@ -0,0 +1,281 @@ +""" +Experimental Claude Agent SDK runner for the chat endpoint. + +This is deliberately opt-in. It lets us compare the Claude Agent SDK harness +against the existing direct Anthropic Messages loop while keeping the same +PolicyEngine backend registry and frontend SSE contract. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, AsyncIterator, Dict, List + +from agent_tools import execute_tool +from model_backends import get_backend + +logger = logging.getLogger(__name__) + + +def _conversation_prompt(messages: List[dict]) -> str: + lines = [ + "Continue this chat transcript. Respond only to the latest user message.", + "", + ] + for message in messages: + role = "User" if message.get("role") == "user" else "Assistant" + lines.append(f"{role}: {message.get('content', '')}") + lines.append("") + return "\n".join(lines).strip() + + +def _serialise_tool_result(result: Any) -> str: + return json.dumps(result, ensure_ascii=False, default=str) + + +def _usage_totals(usage: dict[str, Any] | None) -> dict[str, int]: + if not usage: + return { + "input_tokens": 0, + "output_tokens": 0, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + } + return { + "input_tokens": int(usage.get("input_tokens", 0) or 0), + "output_tokens": int(usage.get("output_tokens", 0) or 0), + "cache_creation_input_tokens": int( + usage.get("cache_creation_input_tokens", 0) or 0 + ), + "cache_read_input_tokens": int(usage.get("cache_read_input_tokens", 0) or 0), + } + + +def _event_type(event: dict[str, Any]) -> str: + return str(event.get("type", "")) + + +def _content_block(event: dict[str, Any]) -> dict[str, Any]: + block = event.get("content_block") + return block if isinstance(block, dict) else {} + + +def _delta(event: dict[str, Any]) -> dict[str, Any]: + delta = event.get("delta") + return delta if isinstance(delta, dict) else {} + + +async def generate_claude_agent_sdk_stream( + *, + conversation: List[dict], + system_prompt: str, + plan_mode: bool, + session_id: str, + user_id: str | None, + backend_id: str, + model: str, +) -> AsyncIterator[str]: + try: + from claude_agent_sdk import ClaudeAgentOptions, create_sdk_mcp_server, query, tool + from claude_agent_sdk.types import ( + AssistantMessage, + ResultMessage, + StreamEvent, + TextBlock, + ToolResultBlock, + ToolUseBlock, + ) + except ImportError as exc: + yield f"data: {json.dumps({'type': 'error', 'content': f'Claude Agent SDK is not installed: {exc}'})}\n\n" + return + + backend = get_backend(backend_id) + + @tool( + "run_python", + backend.tool_description(), + { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": ( + "Python code to execute. Must assign the final answer to " + "`result`. Use the preloaded model interface directly." + ), + } + }, + "required": ["code"], + }, + ) + async def run_python_tool(args: dict[str, Any]) -> dict[str, Any]: + result = execute_tool("run_python", {"code": args.get("code", "")}, backend.id) + tool_result_queue.append(_queueable_tool_result(result)) + return { + "content": [ + { + "type": "text", + "text": _serialise_tool_result(result), + } + ] + } + + mcp_server = create_sdk_mcp_server( + name="policyengine", + version="0.1.0", + tools=[run_python_tool], + ) + allowed_tools = [] if plan_mode else ["mcp__policyengine__run_python"] + options = ClaudeAgentOptions( + system_prompt=system_prompt, + model=model, + mcp_servers={"policyengine": mcp_server}, + allowed_tools=allowed_tools, + permission_mode="plan" if plan_mode else "default", + include_partial_messages=True, + max_turns=1 if plan_mode else 60, + ) + + assistant_content = "" + usage = _usage_totals(None) + tool_inputs: Dict[str, dict[str, Any]] = {} + tool_names: Dict[str, str] = {} + tool_order: List[str] = [] + tool_result_queue: List[dict[str, Any]] = [] + emitted_tool_results: set[str] = set() + announced_tool_ids: set[str] = set() + + def _queueable_tool_result(result: dict[str, Any]) -> dict[str, Any]: + result_json = _serialise_tool_result(result) + return { + "status": "error" if result.get("error") else "success", + "result_summary": ( + result_json[:5000] + "..." if len(result_json) > 5000 else result_json + ), + } + + def _flush_tool_result_events() -> list[str]: + events = [] + while tool_result_queue: + tool_id = next( + (candidate for candidate in tool_order if candidate not in emitted_tool_results), + None, + ) + if not tool_id: + break + emitted_tool_results.add(tool_id) + result = tool_result_queue.pop(0) + events.append( + f"data: {json.dumps({'type': 'tool_result', 'tool_name': 'run_python', 'tool_id': tool_id, 'status': result['status'], 'result_summary': result['result_summary']})}\n\n" + ) + return events + + try: + async for message in query( + prompt=_conversation_prompt(conversation), + options=options, + ): + for event in _flush_tool_result_events(): + yield event + + if isinstance(message, StreamEvent): + event = message.event + event_type = _event_type(event) + if event_type == "content_block_start": + block = _content_block(event) + if block.get("type") == "tool_use": + tool_id = str(block.get("id", "")) + tool_name = str(block.get("name", "run_python")) + if tool_id: + tool_names[tool_id] = tool_name + tool_inputs.setdefault(tool_id, {}) + if tool_id not in tool_order: + tool_order.append(tool_id) + if tool_id and tool_id not in announced_tool_ids: + announced_tool_ids.add(tool_id) + yield f"data: {json.dumps({'type': 'tool_start', 'tool_name': 'run_python', 'tool_id': tool_id})}\n\n" + elif event_type == "content_block_delta": + delta = _delta(event) + if delta.get("type") == "text_delta" and delta.get("text"): + text = str(delta.get("text", "")) + assistant_content += text + yield f"data: {json.dumps({'type': 'chunk', 'content': text})}\n\n" + + elif isinstance(message, AssistantMessage): + if message.usage: + usage = _usage_totals(message.usage) + for block in message.content: + if isinstance(block, TextBlock): + if not assistant_content and block.text: + assistant_content += block.text + yield f"data: {json.dumps({'type': 'chunk', 'content': block.text})}\n\n" + elif isinstance(block, ToolUseBlock): + tool_names[block.id] = block.name + tool_inputs[block.id] = block.input + if block.id not in tool_order: + tool_order.append(block.id) + yield f"data: {json.dumps({'type': 'tool_use', 'tool_name': 'run_python', 'tool_id': block.id, 'tool_input': block.input, 'status': 'pending'})}\n\n" + elif isinstance(block, ToolResultBlock): + content = block.content + if isinstance(content, list): + content_text = "\n".join( + str(item.get("text", item)) + for item in content + if isinstance(item, dict) + ) + else: + content_text = str(content or "") + result_summary = ( + content_text[:5000] + "..." + if len(content_text) > 5000 + else content_text + ) + yield f"data: {json.dumps({'type': 'tool_result', 'tool_name': 'run_python', 'tool_id': block.tool_use_id, 'status': 'error' if block.is_error else 'success', 'result_summary': result_summary})}\n\n" + + elif isinstance(message, ResultMessage): + for event in _flush_tool_result_events(): + yield event + if message.usage: + usage = _usage_totals(message.usage) + if message.result and not assistant_content: + assistant_content = message.result + yield f"data: {json.dumps({'type': 'chunk', 'content': message.result})}\n\n" + billing = None + try: + from routes.billing import record_usage + + billing = record_usage( + user_id=user_id, + session_id=session_id, + model=model, + input_tokens=usage["input_tokens"], + output_tokens=usage["output_tokens"], + cache_creation_input_tokens=usage[ + "cache_creation_input_tokens" + ], + cache_read_input_tokens=usage["cache_read_input_tokens"], + ) + except Exception as exc: + logger.warning(f"[CHAT][SDK] Failed to record usage: {exc}") + + done = { + "type": "done", + "content": assistant_content, + "session_id": session_id, + "model": model, + "model_backend": backend.id, + "agent_runner": "claude_sdk", + "usage": usage, + "cost_gbp": billing["cost_gbp"] if billing else None, + "balance": billing["balance"] if billing else None, + "sdk_session_id": message.session_id, + "sdk_cost_usd": message.total_cost_usd, + } + yield f"data: {json.dumps(done)}\n\n" + return + + yield f"data: {json.dumps({'type': 'done', 'content': assistant_content, 'session_id': session_id, 'model': model, 'model_backend': backend.id, 'agent_runner': 'claude_sdk', 'usage': usage, 'cost_gbp': None, 'balance': None})}\n\n" + except Exception as exc: + logger.exception("[CHAT][SDK] Claude Agent SDK runner failed") + yield f"data: {json.dumps({'type': 'error', 'content': str(exc)})}\n\n" diff --git a/backend/model_backends.py b/backend/model_backends.py new file mode 100644 index 0000000..f21d52f --- /dev/null +++ b/backend/model_backends.py @@ -0,0 +1,370 @@ +""" +Model backend adapters for the chat agent. + +The chat UI currently exposes one flexible Python execution tool. These +adapters keep that contract stable while allowing the preloaded model +interface, prompt context, and tool documentation to vary by backend. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from importlib.metadata import PackageNotFoundError, version +import json +import math +from pathlib import Path +import sys +from typing import Any, Callable, Dict, Iterable + + +class BackendImportError(RuntimeError): + """Raised when a selected model backend cannot be imported.""" + + +def _ensure_sibling_package_importable( + import_name: str, + sibling_candidates: Iterable[Path], +) -> None: + try: + __import__(import_name) + return + except ModuleNotFoundError: + pass + + added_paths = [] + for candidate in sibling_candidates: + if candidate.is_dir(): + candidate_str = str(candidate) + if candidate_str not in sys.path: + sys.path.insert(0, candidate_str) + added_paths.append(candidate_str) + + try: + __import__(import_name) + return + except ModuleNotFoundError as exc: + detail = str(exc) + + raise BackendImportError( + f"{import_name} is not importable. Install the package or make sure a " + "local checkout is available next to policyengine-uk-chat. " + f"Added paths: {added_paths or 'none'}. Import error: {detail}" + ) + + +@dataclass(frozen=True) +class ModelBackend: + id: str + display_name: str + package_name: str + package_label: str + import_roots: frozenset[str] + + def package_version(self) -> str: + try: + return version(self.package_name) + except PackageNotFoundError: + return "unknown" + + def prompt_context(self) -> str: + raise NotImplementedError + + def tool_description(self) -> str: + raise NotImplementedError + + def execution_globals(self) -> Dict[str, Any]: + raise NotImplementedError + + +class UKCompiledBackend(ModelBackend): + def __init__(self) -> None: + super().__init__( + id="uk_compiled", + display_name="PolicyEngine UK compiled Rust backend", + package_name="policyengine-uk-compiled", + package_label="policyengine-uk-compiled", + import_roots=frozenset( + {"json", "math", "numpy", "pandas", "policyengine_uk_compiled"} + ), + ) + + def _ensure_importable(self) -> None: + repo_parent = Path(__file__).resolve().parents[2] + _ensure_sibling_package_importable( + "policyengine_uk_compiled", + [ + repo_parent / "policyengine-uk-rust" / "interfaces" / "python", + repo_parent + / "policyengine-uk-rust-codex-debug-issue" + / "interfaces" + / "python", + ], + ) + + def prompt_context(self) -> str: + return """CRITICAL - USE THE OFFICIAL POLICYENGINE UK COMPILED INTERFACE: +- The selected backend is `uk_compiled`, the Rust-backed UK engine exposed through `policyengine_uk_compiled`. +- The Python environment preloads: + `policyengine_uk_compiled` as `pe` + `Simulation` + `Parameters` + `StructuralReform` + `aggregate_microdata` + `combine_microdata` + `capabilities` + `ensure_dataset` + `pd`, `np`, `json`, `math` +- Prefer writing code directly against those objects so the run is reproducible outside chat. +- Do not recreate policy logic manually if the package already provides it. + +COMMON WORKFLOWS FOR THIS BACKEND: +- Baseline economy-wide run: + `caps = capabilities()` + `sim = Simulation(year=2025, dataset="frs")` + `result = sim.run().model_dump()` +- Reform run: + `policy = Parameters.model_validate({"income_tax": {"personal_allowance": 15000}})` + `result = sim.run(policy=policy).model_dump()` +- Custom household run: + build `persons`, `benunits`, and `households` DataFrames, then pass them to `Simulation(...)` +- Microdata analysis: + `micro = sim.run_microdata(...)` then analyse `micro.persons`, `micro.benunits`, or `micro.households` with pandas + +MODELLING SCOPE: +- The compiled backend covers the model surface exposed by `policyengine_uk_compiled`. +- Use `capabilities()` to check what is available locally before committing to an approach.""" + + def tool_description(self) -> str: + return ( + "Execute reproducible Python code using the PolicyEngine UK compiled " + "backend. The environment preloads `policyengine_uk_compiled` as `pe`, " + "plus `Simulation`, `Parameters`, `StructuralReform`, " + "`aggregate_microdata`, `combine_microdata`, `capabilities`, " + "`ensure_dataset`, `pd`, `np`, `json`, and `math`. Assign the final " + "answer to `result` and use `print()` for short diagnostics." + ) + + def execution_globals(self) -> Dict[str, Any]: + self._ensure_importable() + import pandas as pd + import policyengine_uk_compiled as pe + + try: + import numpy as np + except ImportError: + np = None + + from policyengine_uk_compiled import ( + Parameters, + Simulation, + StructuralReform, + aggregate_microdata, + capabilities, + combine_microdata, + ensure_dataset, + ) + + globals_dict: Dict[str, Any] = { + "math": math, + "json": json, + "pd": pd, + "pe": pe, + "Simulation": Simulation, + "StructuralReform": StructuralReform, + "Parameters": Parameters, + "aggregate_microdata": aggregate_microdata, + "combine_microdata": combine_microdata, + "capabilities": capabilities, + "ensure_dataset": ensure_dataset, + } + if np is not None: + globals_dict["np"] = np + globals_dict["numpy"] = np + return globals_dict + + +class UKPolicyEnginePythonBackend(ModelBackend): + def __init__(self) -> None: + super().__init__( + id="uk_python", + display_name="PolicyEngine UK Python backend", + package_name="policyengine-uk", + package_label="policyengine-uk", + import_roots=frozenset( + { + "json", + "math", + "numpy", + "pandas", + "policyengine", + "policyengine_core", + "policyengine_uk", + "microdf", + } + ), + ) + + def _ensure_importable(self) -> None: + repo_parent = Path(__file__).resolve().parents[2] + _ensure_sibling_package_importable( + "policyengine_uk", + [ + repo_parent / "policyengine-core", + repo_parent / "policyengine.py" / "src", + repo_parent / "policyengine-uk", + ], + ) + + def prompt_context(self) -> str: + return """CRITICAL - USE THE POLICYENGINE UK PYTHON MODEL INTERFACE: +- The selected backend is `uk_python`, the Python `policyengine-uk` model package. +- This is the detailed PolicyEngine Core/OpenFisca-style UK model, not the compiled Rust wrapper. +- The Python environment preloads: + `policyengine_uk` as `pe` + `Simulation` + `Microsimulation` + `CountryTaxBenefitSystem` + `Scenario` + `capabilities` + `pd`, `np`, `json`, `math` +- If installed, the higher-level `policyengine` package is also preloaded as `policyengine`. +- Prefer writing code against `policyengine_uk` objects and formulas rather than recreating policy logic. + +COMMON WORKFLOWS FOR THIS BACKEND: +- First inspect backend details: + `result = capabilities()` +- Custom household/situation run: + `sim = Simulation(situation={...})` + `result = sim.calculate("household_net_income", 2025).tolist()` +- Microsimulation from published UK data: + `sim = Microsimulation(dataset="hf://policyengine/policyengine-uk-data/enhanced_frs_2023_24.h5")` + `result = sim.calculate("household_net_income", 2025).head().to_list()` +- Parameter reform: + pass parameter changes through `Scenario` or mutate a simulation with documented `policyengine_uk` helpers. + +MODELLING SCOPE: +- This backend exposes the Python `policyengine-uk` model surface. Its API, datasets, variables, and results can differ from `uk_compiled`. +- If a dataset is unavailable locally or requires a download/token, report that clearly instead of guessing.""" + + def tool_description(self) -> str: + return ( + "Execute reproducible Python code using the Python `policyengine-uk` " + "backend. The environment preloads `policyengine_uk` as `pe`, " + "`Simulation`, `Microsimulation`, `CountryTaxBenefitSystem`, " + "`Scenario`, `capabilities`, `pd`, `np`, `json`, and `math`; " + "the higher-level `policyengine` package is available when installed. " + "Assign the final answer to `result` and use `print()` for short diagnostics." + ) + + def execution_globals(self) -> Dict[str, Any]: + self._ensure_importable() + import pandas as pd + import policyengine_uk as pe + + try: + import numpy as np + except ImportError: + np = None + + try: + import policyengine + except ImportError: + policyengine = None + + from policyengine_uk import ( + CountryTaxBenefitSystem, + Microsimulation, + Simulation, + ) + from policyengine_uk.utils.scenario import Scenario + + def capabilities() -> Dict[str, Any]: + system = CountryTaxBenefitSystem() + variables = system.variables + parameters = system.parameters + return { + "backend": self.id, + "display_name": self.display_name, + "package": "policyengine-uk", + "interface": "Python PolicyEngine Core/OpenFisca-style model", + "preloaded": [ + "policyengine_uk as pe", + "Simulation", + "Microsimulation", + "CountryTaxBenefitSystem", + "Scenario", + "pd", + "np", + "json", + "math", + ], + "variable_count": len(variables), + "sample_variables": sorted(variables)[:50], + "parameter_root_children": sorted(parameters.children.keys()), + "dataset_notes": [ + "Pass a situation dict for household-style calculations.", + "Pass a UKSingleYearDataset, UKMultiYearDataset, DataFrame, or hf:// URL for microsimulation.", + "No default dataset is used unless POLICYENGINE_UK_DEFAULT_DATASET is set.", + ], + "comparison_note": ( + "Results may differ from uk_compiled because this backend uses " + "the Python policyengine-uk model and its datasets/API surface." + ), + } + + globals_dict: Dict[str, Any] = { + "math": math, + "json": json, + "pd": pd, + "pe": pe, + "policyengine_uk": pe, + "Simulation": Simulation, + "Microsimulation": Microsimulation, + "CountryTaxBenefitSystem": CountryTaxBenefitSystem, + "Scenario": Scenario, + "capabilities": capabilities, + } + if policyengine is not None: + globals_dict["policyengine"] = policyengine + if np is not None: + globals_dict["np"] = np + globals_dict["numpy"] = np + return globals_dict + + +_BACKENDS: Dict[str, ModelBackend] = { + "uk_compiled": UKCompiledBackend(), + "uk_python": UKPolicyEnginePythonBackend(), +} + + +def available_backends() -> Dict[str, Dict[str, str]]: + return { + backend_id: { + "id": backend.id, + "display_name": backend.display_name, + "package_label": backend.package_label, + "version": backend.package_version(), + } + for backend_id, backend in _BACKENDS.items() + } + + +def get_backend(backend_id: str | None = None) -> ModelBackend: + selected = backend_id or "uk_compiled" + if selected not in _BACKENDS: + valid = ", ".join(sorted(_BACKENDS)) + raise ValueError(f"Unknown model backend '{selected}'. Valid backends: {valid}") + return _BACKENDS[selected] + + +def make_backend_importer(backend: ModelBackend) -> Callable[..., Any]: + def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): + root_name = name.split(".")[0] + if root_name not in backend.import_roots: + raise ImportError( + f"Import of '{name}' is not allowed for backend '{backend.id}'" + ) + return __import__(name, globals, locals, fromlist, level) + + return _safe_import diff --git a/backend/requirements.txt b/backend/requirements.txt index 4feb984..5950419 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -3,10 +3,12 @@ uvicorn[standard] sqlmodel psycopg2-binary anthropic +claude-agent-sdk pydantic-ai[anthropic] policyengine-uk-compiled>=0.20.0 -policyengine_uk +policyengine_uk==2.88.0 pandas httpx supabase stripe +python-dateutil diff --git a/backend/routes/chatbot.py b/backend/routes/chatbot.py index fa1af54..3bc9d93 100644 --- a/backend/routes/chatbot.py +++ b/backend/routes/chatbot.py @@ -17,7 +17,8 @@ from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.settings import ModelSettings -from agent_tools import execute_tool, TOOL_DEFINITIONS +from agent_tools import execute_tool, get_tool_definitions +from model_backends import available_backends, get_backend logger = logging.getLogger(__name__) @@ -26,7 +27,7 @@ # --------------------------------------------------------------------------- # System prompt # --------------------------------------------------------------------------- -SYSTEM_PROMPT = """You are an expert policy analysis assistant for a UK microsimulation platform. You help users understand and analyse UK tax and benefit policy using reproducible Python code. +SYSTEM_PROMPT_TEMPLATE = """You are an expert policy analysis assistant for a microsimulation platform. You help users understand and analyse tax and benefit policy using reproducible Python code. CRITICAL - ALWAYS COMPUTE WITH PYTHON: - Never answer quantitative policy questions from memory. @@ -39,19 +40,7 @@ - Use that to ground yourself in the available datasets, years, programmes, and caveats before you simulate. - If the user asks about something outside the modelled scope, say so clearly instead of guessing. -CRITICAL - USE THE OFFICIAL POLICYENGINE PYTHON INTERFACE: -- The Python environment preloads: - `policyengine_uk_compiled` as `pe` - `Simulation` - `Parameters` - `StructuralReform` - `aggregate_microdata` - `combine_microdata` - `capabilities` - `ensure_dataset` - `pd`, `np`, `json`, `math` -- Prefer writing code directly against those objects so the run is reproducible outside chat. -- Do not recreate policy logic manually if the package already provides it. +{backend_prompt_context} REPRODUCIBILITY RULES: - Write clear Python that another developer could copy and run. @@ -61,10 +50,10 @@ - Do not rely on hidden reasoning for calculations when code can do the work. API AND DATASETS: -- A live API reference (docstrings, `capabilities()` snapshot, full `Parameters` JSON schema) is attached to this system prompt — consult it for signatures, reform keys, and dataset descriptions rather than guessing. +- A live API reference (docstrings, `capabilities()` snapshot, full `Parameters` JSON schema) is attached to this system prompt — consult it for signatures, reform keys, and dataset descriptions rather than guessing. The reference reflects the currently selected backend. - Call `capabilities()` at the start of a new line of analysis to check what's modelled and locally available before committing to an approach. - Tell the user which dataset you used when it matters. -- If something is not modelled well enough for a quantitative answer, say so clearly and do not fabricate estimates. +- If a dataset is unavailable in the selected backend, explain the limitation rather than fabricating estimates. ANALYTICAL NOTES: - Decile impacts are decile-level averages, not economy-wide means. @@ -80,11 +69,18 @@ """ +def _build_system_prompt(backend_id: str) -> str: + backend = get_backend(backend_id) + return SYSTEM_PROMPT_TEMPLATE.format( + backend_prompt_context=backend.prompt_context() + ) + + # --------------------------------------------------------------------------- # Pydantic-AI agent setup # --------------------------------------------------------------------------- -# We build the agent with tools dynamically from TOOL_DEFINITIONS +# We build the agent with tools dynamically per backend via get_tool_definitions. # pydantic-ai uses its own tool registration, but we'll drive it through # our own SSE loop using the underlying model API directly for streaming. @@ -123,18 +119,19 @@ def _get_sync_anthropic_client(): return anthropic_sdk.Anthropic(api_key=api_key) -def _tool_defs_for_anthropic(): - """Convert our TOOL_DEFINITIONS to Anthropic SDK format. +def _tool_defs_for_anthropic(backend_id: str): + """Convert backend tool definitions to Anthropic SDK format. Mark the last tool with cache_control so the system prompt + all tools are cached across requests (prompt caching).""" defs = [] - for i, t in enumerate(TOOL_DEFINITIONS): + tool_definitions = get_tool_definitions(backend_id) + for i, t in enumerate(tool_definitions): d = { "name": t["name"], "description": t["description"], "input_schema": t["input_schema"], } - if i == len(TOOL_DEFINITIONS) - 1: + if i == len(tool_definitions) - 1: d["cache_control"] = {"type": "ephemeral"} defs.append(d) return defs @@ -149,10 +146,10 @@ def _estimate_message_tokens(messages: List[dict]) -> int: return char_count // 4 -def _select_chat_model(messages: List[dict]) -> str: +def _select_chat_model(messages: List[dict], backend_id: str) -> str: estimated_input_tokens = ( _estimate_message_tokens(messages) - + len(SYSTEM_PROMPT) // 4 + + len(_build_system_prompt(backend_id)) // 4 + len(REFERENCE_DOC) // 4 ) if estimated_input_tokens > FAST_MODEL_MAX_INPUT_TOKENS: @@ -160,7 +157,7 @@ def _select_chat_model(messages: List[dict]) -> str: return DEFAULT_FAST_MODEL -def _build_system_blocks(plan_mode: bool = False) -> List[dict]: +def _build_system_blocks(backend_id: str, plan_mode: bool = False) -> List[dict]: """System prompt + cached library reference + optional plan-mode directive. The system prompt and reference are each marked with cache_control so they @@ -169,7 +166,7 @@ def _build_system_blocks(plan_mode: bool = False) -> List[dict]: """ blocks: List[dict] = [{ "type": "text", - "text": SYSTEM_PROMPT, + "text": _build_system_prompt(backend_id), "cache_control": {"type": "ephemeral"}, }] if REFERENCE_DOC: @@ -196,6 +193,7 @@ class ChatRequest(BaseModel): messages: List[ChatMessage] session_id: str | None = None user_id: str | None = None + model_backend: str | None = None plan_mode: bool = False @@ -239,6 +237,14 @@ def generate_title(request: TitleRequest): return {"title": response.content[0].text.strip()} +@router.get("/backends") +def list_backends(): + return { + "default": os.environ.get("POLICYENGINE_CHAT_BACKEND", "uk_compiled"), + "backends": available_backends(), + } + + # --------------------------------------------------------------------------- # Chat endpoint — SSE streaming # --------------------------------------------------------------------------- @@ -257,6 +263,13 @@ async def chat_message(request: ChatRequest, http_request: Request): pass # Supabase not configured — skip billing check session_id = request.session_id or str(uuid.uuid4()) + backend_id = request.model_backend or os.environ.get( + "POLICYENGINE_CHAT_BACKEND", "uk_compiled" + ) + try: + backend = get_backend(backend_id) + except ValueError as e: + return JSONResponse(status_code=400, content={"error": str(e)}) messages = [{"role": msg.role, "content": msg.content} for msg in request.messages] @@ -268,6 +281,27 @@ async def chat_message(request: ChatRequest, http_request: Request): else: deduplicated[-1]["content"] += "\n\n" + msg["content"] + if os.environ.get("POLICYENGINE_CHAT_AGENT_RUNNER") == "claude_sdk": + from claude_agent_sdk_runner import generate_claude_agent_sdk_stream + + # The SDK runner takes a single system_prompt string. Plan mode is still + # enforced inside it (see runner). Cached reference doc is *not* sent + # through this path yet — the SDK runner is opt-in/experimental. + return StreamingResponse( + generate_claude_agent_sdk_stream( + conversation=deduplicated.copy(), + system_prompt=_build_system_prompt(backend.id) + + ("\n\n" + PLAN_MODE_DIRECTIVE if request.plan_mode else ""), + plan_mode=request.plan_mode, + session_id=session_id, + user_id=user_id, + backend_id=backend.id, + model=_select_chat_model(deduplicated, backend.id), + ), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + async def generate_stream(): try: conversation = deduplicated.copy() @@ -280,13 +314,13 @@ async def generate_stream(): recent_tool_calls: List[str] = [] client = _get_anthropic_client() - model = _select_chat_model(conversation) - tools = _tool_defs_for_anthropic() + model = _select_chat_model(conversation, backend.id) + tools = _tool_defs_for_anthropic(backend.id) plan_mode = request.plan_mode - system_blocks = _build_system_blocks(plan_mode=plan_mode) + system_blocks = _build_system_blocks(backend.id, plan_mode=plan_mode) logger.info( - f"[CHAT] Session {session_id}: {len(conversation)} messages" + f"[CHAT] Session {session_id}: {len(conversation)} messages, backend={backend.id}" f"{' [PLAN MODE]' if plan_mode else ''}" ) @@ -394,7 +428,7 @@ async def generate_stream(): ) except Exception as e: logger.warning(f"[CHAT] Failed to record usage: {e}") - yield f"data: {json.dumps({'type': 'done', 'content': assistant_content, 'session_id': session_id, 'model': model, 'usage': {'input_tokens': total_input_tokens, 'output_tokens': total_output_tokens, 'cache_creation_input_tokens': total_cache_creation_input_tokens, 'cache_read_input_tokens': total_cache_read_input_tokens}, 'cost_gbp': billing['cost_gbp'] if billing else None, 'balance': billing['balance'] if billing else None})}\n\n" + yield f"data: {json.dumps({'type': 'done', 'content': assistant_content, 'session_id': session_id, 'model': model, 'model_backend': backend.id, 'usage': {'input_tokens': total_input_tokens, 'output_tokens': total_output_tokens, 'cache_creation_input_tokens': total_cache_creation_input_tokens, 'cache_read_input_tokens': total_cache_read_input_tokens}, 'cost_gbp': billing['cost_gbp'] if billing else None, 'balance': billing['balance'] if billing else None})}\n\n" break # Detect infinite loops @@ -420,7 +454,9 @@ async def generate_stream(): async def execute_tool_async(tu): loop = asyncio.get_event_loop() logger.info(f"[CHAT] Starting tool: {tu['name']} input={tu['input']}") - result = await loop.run_in_executor(None, execute_tool, tu["name"], tu["input"]) + result = await loop.run_in_executor( + None, execute_tool, tu["name"], tu["input"], backend.id + ) logger.info(f"[CHAT] Finished tool: {tu['name']} result_keys={list(result.keys()) if isinstance(result, dict) else type(result)}") return tu, result @@ -475,7 +511,7 @@ async def execute_tool_async(tu): except Exception as e: logger.warning(f"[CHAT] Failed to record usage: {e}") yield f"data: {json.dumps({'type': 'chunk', 'content': '\\n\\n*[Reached maximum iterations]*'})}\n\n" - yield f"data: {json.dumps({'type': 'done', 'content': assistant_content, 'session_id': session_id, 'model': model, 'usage': {'input_tokens': total_input_tokens, 'output_tokens': total_output_tokens, 'cache_creation_input_tokens': total_cache_creation_input_tokens, 'cache_read_input_tokens': total_cache_read_input_tokens}, 'cost_gbp': billing['cost_gbp'] if billing else None, 'balance': billing['balance'] if billing else None})}\n\n" + yield f"data: {json.dumps({'type': 'done', 'content': assistant_content, 'session_id': session_id, 'model': model, 'model_backend': backend.id, 'usage': {'input_tokens': total_input_tokens, 'output_tokens': total_output_tokens, 'cache_creation_input_tokens': total_cache_creation_input_tokens, 'cache_read_input_tokens': total_cache_read_input_tokens}, 'cost_gbp': billing['cost_gbp'] if billing else None, 'balance': billing['balance'] if billing else None})}\n\n" except Exception as e: import traceback diff --git a/backend/tests/test_agent_tools.py b/backend/tests/test_agent_tools.py index 0740e52..773a316 100644 --- a/backend/tests/test_agent_tools.py +++ b/backend/tests/test_agent_tools.py @@ -11,10 +11,39 @@ compute, generate_chart, execute_tool, + get_tool_definitions, _build_compiled_policy, _json_safe, run_python, ) +from model_backends import available_backends, get_backend + + +# --------------------------------------------------------------------------- +# model backends +# --------------------------------------------------------------------------- + +class TestModelBackends: + def test_available_backends_include_compiled_and_python(self): + backends = available_backends() + assert "uk_compiled" in backends + assert "uk_python" in backends + assert backends["uk_compiled"]["package_label"] == "policyengine-uk-compiled" + assert backends["uk_python"]["package_label"] == "policyengine-uk" + assert "version" in backends["uk_compiled"] + assert "version" in backends["uk_python"] + + def test_backend_tool_descriptions_are_backend_specific(self): + compiled_tools = get_tool_definitions("uk_compiled") + python_tools = get_tool_definitions("uk_python") + assert compiled_tools[0]["name"] == "run_python" + assert python_tools[0]["name"] == "run_python" + assert "compiled" in compiled_tools[0]["description"] + assert "policyengine-uk" in python_tools[0]["description"] + + def test_unknown_backend_rejected(self): + with pytest.raises(ValueError): + get_backend("not_a_backend") # --------------------------------------------------------------------------- diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index b552c99..e47ee4a 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -23,6 +23,20 @@ def test_health(self): assert r.json()["status"] == "ok" +class TestChatBackends: + def test_lists_backends(self): + r = client.get("/chat/backends") + assert r.status_code == 200 + data = r.json() + assert data["default"] == "uk_compiled" + assert "uk_compiled" in data["backends"] + assert "uk_python" in data["backends"] + assert data["backends"]["uk_compiled"]["package_label"] == "policyengine-uk-compiled" + assert data["backends"]["uk_python"]["package_label"] == "policyengine-uk" + assert "version" in data["backends"]["uk_compiled"] + assert "version" in data["backends"]["uk_python"] + + # --------------------------------------------------------------------------- # Conversations CRUD # --------------------------------------------------------------------------- @@ -172,6 +186,14 @@ def parse_sse(response_text: str) -> list[dict]: class TestChatMessage: + def test_unknown_model_backend_returns_400(self): + r = client.post("/chat/message", json={ + "messages": [{"role": "user", "content": "hello"}], + "model_backend": "not_a_backend", + }) + assert r.status_code == 400 + assert "Unknown model backend" in r.json()["error"] + def test_simple_chat_returns_sse(self): with client.stream("POST", "/chat/message", json={ "messages": [{"role": "user", "content": "Say exactly: hello"}], @@ -267,18 +289,18 @@ class TestPlanMode: def test_directive_present_when_plan_mode_on(self): from routes.chatbot import _build_system_blocks, PLAN_MODE_DIRECTIVE - blocks = _build_system_blocks(plan_mode=True) + blocks = _build_system_blocks("uk_compiled", plan_mode=True) assert any(PLAN_MODE_DIRECTIVE in b.get("text", "") for b in blocks) def test_directive_absent_when_plan_mode_off(self): from routes.chatbot import _build_system_blocks, PLAN_MODE_DIRECTIVE - blocks = _build_system_blocks(plan_mode=False) + blocks = _build_system_blocks("uk_compiled", plan_mode=False) assert not any(PLAN_MODE_DIRECTIVE in b.get("text", "") for b in blocks) def test_base_prompt_cache_breakpoint_unchanged(self): from routes.chatbot import _build_system_blocks - on = _build_system_blocks(plan_mode=True) - off = _build_system_blocks(plan_mode=False) + on = _build_system_blocks("uk_compiled", plan_mode=True) + off = _build_system_blocks("uk_compiled", plan_mode=False) assert on[0] == off[0] assert "cache_control" in on[0] diff --git a/frontend/src/app/ChatPage.tsx b/frontend/src/app/ChatPage.tsx index b650092..bcef580 100644 --- a/frontend/src/app/ChatPage.tsx +++ b/frontend/src/app/ChatPage.tsx @@ -125,6 +125,29 @@ interface BalanceSummary { total_available_gbp: number; } +interface ModelBackendOption { + id: string; + display_name: string; + package_label: string; + version: string; +} + +interface ModelBackendsResponse { + default: string; + backends: Record; +} + +function formatBackendLabel(backend: ModelBackendOption): string { + if (backend.id === "uk_compiled") return "Compiled"; + if (backend.id === "uk_python") return "Python"; + return backend.display_name; +} + +function formatBackendVersion(backend: ModelBackendOption | undefined): string | null { + if (!backend?.package_label || !backend.version || backend.version === "unknown") return null; + return `${backend.package_label} v${backend.version}`; +} + async function apiRequest(method: string, endpoint: string, params?: Record, body?: unknown): Promise { const url = new URL(getBackendEndpoint(endpoint), window.location.origin); if (params) Object.entries(params).forEach(([k, v]) => url.searchParams.set(k, v)); @@ -163,18 +186,21 @@ export default function ChatPage() { const [reportNote, setReportNote] = useState(""); const [reportError, setReportError] = useState(null); const [reportSubmitting, setReportSubmitting] = useState(false); - const [planMode, setPlanMode] = useState(false); const scrollRef = useRef(null); const inputRef = useRef(null); const sessionId = useRef(null); const debugLog = useRef([]); const abortRef = useRef(null); - const [modelVersion, setModelVersion] = useState(null); + const [modelBackends, setModelBackends] = useState([]); + const [selectedBackendId, setSelectedBackendId] = useState("uk_compiled"); const [balance, setBalance] = useState(null); const [topUpLoading, setTopUpLoading] = useState(false); + const [planMode, setPlanMode] = useState(false); const hasMessages = messages.length > 0; const animatedPlaceholder = useAnimatedPlaceholder(EXAMPLE_QUERIES, !hasMessages && !input); + const selectedBackend = modelBackends.find((backend) => backend.id === selectedBackendId); + const selectedBackendVersion = formatBackendVersion(selectedBackend); const fetchBalance = useCallback(async () => { if (!user) return; @@ -193,8 +219,16 @@ export default function ChatPage() { }; useEffect(() => { - apiRequest<{ policyengine_uk_compiled: string }>("GET", "version") - .then((v) => setModelVersion(v.policyengine_uk_compiled)) + apiRequest("GET", "chat/backends") + .then((data) => { + const options = Object.values(data.backends); + const stored = window.localStorage.getItem("policyengine-chat-backend"); + const nextBackend = stored && options.some((backend) => backend.id === stored) + ? stored + : data.default; + setModelBackends(options); + setSelectedBackendId(nextBackend); + }) .catch(() => {}); // Refresh balance after Stripe redirect if (typeof window !== "undefined" && new URLSearchParams(window.location.search).get("topup") === "success") { @@ -302,7 +336,9 @@ export default function ChatPage() { return [{ id: saved.id, session_id: sid, title, created_at: saved.created_at, updated_at: saved.updated_at }, ...filtered]; }); return saved; - } catch (e) { console.error("Failed to save conversation", e); } + } catch { + return null; + } return null; }, [user]); @@ -419,7 +455,13 @@ export default function ChatPage() { const response = await fetch(getBackendEndpoint("chat/message"), { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ messages: apiMessages, session_id: sessionId.current, user_id: user?.id || null, plan_mode: planMode }), + body: JSON.stringify({ + messages: apiMessages, + session_id: sessionId.current, + user_id: user?.id || null, + model_backend: selectedBackendId, + plan_mode: planMode, + }), signal: controller.signal, }); if (response.status === 402) { @@ -545,6 +587,11 @@ export default function ChatPage() { const stopStreaming = () => { abortRef.current?.abort(); }; + const handleBackendChange = (backendId: string) => { + setSelectedBackendId(backendId); + window.localStorage.setItem("policyengine-chat-backend", backendId); + }; + const handleKeyDown = (e: React.KeyboardEvent) => { if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); sendMessage(); } }; @@ -979,36 +1026,72 @@ export default function ChatPage() { /> -
- - {!hasMessages && ( -
- Press Enter to send · Shift+Enter for new line - {modelVersion && policyengine-uk v{modelVersion}} -
- )} +
+
+ + {!hasMessages && Press Enter to send · Shift+Enter for new line} +
+
+ {modelBackends.length > 1 && ( +
+ Engine +
+ {modelBackends.map((backend) => { + const selected = backend.id === selectedBackendId; + return ( + + ); + })} +
+
+ )} + {selectedBackendVersion && {selectedBackendVersion}} +
diff --git a/frontend/src/app/layout.tsx b/frontend/src/app/layout.tsx index 6d41929..a3ed6a1 100644 --- a/frontend/src/app/layout.tsx +++ b/frontend/src/app/layout.tsx @@ -1,7 +1,10 @@ import type { Metadata } from "next"; +import Script from "next/script"; import "@mantine/core/styles.css"; import Providers from "./Providers"; +const GA_MEASUREMENT_ID = "G-2YHG89FY0N"; + export const metadata: Metadata = { title: "PolicyEngine UK", description: "UK tax and benefit microsimulation assistant", @@ -12,6 +15,19 @@ export default function RootLayout({ children }: { children: React.ReactNode }) return ( + diff --git a/modal_app.py b/modal_app.py index 5b91306..884afef 100644 --- a/modal_app.py +++ b/modal_app.py @@ -25,21 +25,7 @@ def _preload_engine(): image = ( modal.Image.debian_slim(python_version="3.13") .apt_install("libpq-dev", "gcc") - .pip_install( - "fastapi", - "uvicorn[standard]", - "sqlmodel", - "psycopg2-binary", - "anthropic", - "pydantic-ai[anthropic]", - "policyengine-uk-compiled>=0.20.0", - "policyengine_uk>=2.75.0", - "pandas", - "httpx", - "supabase", - "stripe", - "python-dateutil", - ) + .pip_install_from_requirements("backend/requirements.txt") .run_function(_preload_engine) .add_local_dir("backend", remote_path="/app/backend", copy=True) # Regenerate reference.md against the Modal-installed From ae5d32dd6859dacc4d160797f1045fc0d2d9027e Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 13 May 2026 00:51:19 -0700 Subject: [PATCH 2/7] Derive preview backend URL from VERCEL_GIT_COMMIT_REF, not hostname The hostname-based slug parsing breaks when Vercel truncates long preview hostnames (e.g. feat/model-backend-selector becomes policyengine-uk-chat-git-feat-model-backen-4022d4-policy-engine.vercel.app, yielding a Modal URL the deploy never created). Source the slug from the git ref instead, matching the CI workflow's derivation. Expose the env vars via next.config so they reach the client at build time. Co-Authored-By: Claude Opus 4.7 (1M context) --- frontend/next.config.js | 7 +++++++ frontend/src/utils/backend.ts | 23 +++++++++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/frontend/next.config.js b/frontend/next.config.js index c10e07d..76eaa3f 100644 --- a/frontend/next.config.js +++ b/frontend/next.config.js @@ -1,6 +1,13 @@ /** @type {import('next').NextConfig} */ const nextConfig = { output: "standalone", + env: { + // Surface the git ref to the client so the preview can derive the matching + // Modal backend URL — Vercel's preview *hostname* gets truncated for long + // branch names, so we can't reliably parse it from window.location. + NEXT_PUBLIC_VERCEL_GIT_COMMIT_REF: process.env.VERCEL_GIT_COMMIT_REF, + NEXT_PUBLIC_VERCEL_ENV: process.env.VERCEL_ENV, + }, }; module.exports = nextConfig; diff --git a/frontend/src/utils/backend.ts b/frontend/src/utils/backend.ts index ce0bed1..3be2f5f 100644 --- a/frontend/src/utils/backend.ts +++ b/frontend/src/utils/backend.ts @@ -1,12 +1,19 @@ -function getPreviewBackendBase(): string | null { - if (typeof window === "undefined") return null; - - const match = window.location.hostname.match( - /^policyengine-uk-chat-git-(.+)-policy-engine\.vercel\.app$/, - ); - if (!match) return null; +function slugifyBranchName(value: string): string { + return value + .toLowerCase() + .replace(/[^a-z0-9]+/g, "-") + .replace(/^-+|-+$/g, "") + .replace(/-{2,}/g, "-"); +} - return `https://policyengine--peukchat-${match[1]}-web.modal.run`; +// Match the slug derivation in app/api/proxy/[...slug]/route.ts and in +// .github/workflows/pr-beta-deploy.yml — all three must agree on the Modal +// app name for a given branch. +function getPreviewBackendBase(): string | null { + if (process.env.NEXT_PUBLIC_VERCEL_ENV !== "preview") return null; + const gitRef = process.env.NEXT_PUBLIC_VERCEL_GIT_COMMIT_REF; + if (!gitRef) return null; + return `https://policyengine--peukchat-${slugifyBranchName(gitRef)}-web.modal.run`; } export function getBackendBase(): string { From 402867270db9b9529bd548c33bf778b0caeffcad Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 13 May 2026 01:02:45 -0700 Subject: [PATCH 3/7] Support HOSTNAME_REGEX for CORS; use it on preview deploys Vercel assigns multiple hostnames to the same preview deployment (the full one and a truncated-with-hash variant). Listing both in HOSTNAMES is brittle. Add a HOSTNAME_REGEX env var that overrides HOSTNAMES when set, and use a Vercel-preview-shaped regex in pr-beta-deploy.yml. Production deploys continue to use the explicit HOSTNAMES list. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/pr-beta-deploy.yml | 4 ++-- backend/main.py | 22 +++++++++++++++------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pr-beta-deploy.yml b/.github/workflows/pr-beta-deploy.yml index c4d4c80..91386ee 100644 --- a/.github/workflows/pr-beta-deploy.yml +++ b/.github/workflows/pr-beta-deploy.yml @@ -87,7 +87,7 @@ jobs: SUPABASE_SERVICE_ROLE_KEY="$SUPABASE_SERVICE_ROLE_KEY" \ STRIPE_SECRET_KEY="$STRIPE_SECRET_KEY" \ STRIPE_WEBHOOK_SECRET="$STRIPE_WEBHOOK_SECRET" \ - HOSTNAMES="https://policyengine-uk-chat.vercel.app" \ + HOSTNAME_REGEX="^https://policyengine-uk-chat-git-[a-z0-9-]+-policy-engine\.vercel\.app$" \ PUBLIC_BASE_URL="https://policyengine-uk-chat.vercel.app" \ --force @@ -131,7 +131,7 @@ jobs: SUPABASE_SERVICE_ROLE_KEY="$SUPABASE_SERVICE_ROLE_KEY" \ STRIPE_SECRET_KEY="$STRIPE_SECRET_KEY" \ STRIPE_WEBHOOK_SECRET="$STRIPE_WEBHOOK_SECRET" \ - HOSTNAMES="$FRONTEND_URL" \ + HOSTNAME_REGEX="^https://policyengine-uk-chat-git-[a-z0-9-]+-policy-engine\.vercel\.app$" \ PUBLIC_BASE_URL="$FRONTEND_URL" \ --force diff --git a/backend/main.py b/backend/main.py index c4fafdc..755ec5b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -18,6 +18,10 @@ _hostnames_env = os.environ.get("HOSTNAMES", "") HOSTNAMES = _hostnames_env.split(",") if _hostnames_env else ["*"] +# Optional regex used by preview deploys where Vercel assigns multiple +# hostnames per deployment (full + truncated-with-hash) and we can't +# enumerate them ahead of time. Setting this overrides HOSTNAMES. +HOSTNAME_REGEX = os.environ.get("HOSTNAME_REGEX") or None class NaNSafeJSONResponse(JSONResponse): @@ -41,13 +45,17 @@ def convert(obj): default_response_class=NaNSafeJSONResponse, ) -app.add_middleware( - CORSMiddleware, - allow_origins=HOSTNAMES, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +_cors_kwargs = { + "allow_credentials": True, + "allow_methods": ["*"], + "allow_headers": ["*"], +} +if HOSTNAME_REGEX: + _cors_kwargs["allow_origin_regex"] = HOSTNAME_REGEX +else: + _cors_kwargs["allow_origins"] = HOSTNAMES + +app.add_middleware(CORSMiddleware, **_cors_kwargs) app.include_router(billing.router) app.include_router(chatbot.router) From 3d8ec77f66d483cca4fed7fda323aa852ad14065 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 13 May 2026 01:26:28 -0700 Subject: [PATCH 4/7] Pipe HUGGING_FACE_TOKEN through to Modal secrets The uk_python backend reads HF datasets via policyengine-core's hugging_face downloader, which expects HUGGING_FACE_TOKEN in the env. Without it, microsim runs fail with 401 on any private dataset. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/deploy.yml | 2 ++ .github/workflows/pr-beta-deploy.yml | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 49dd254..e636861 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -33,6 +33,7 @@ jobs: MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} POLICYENGINE_UK_DATA_TOKEN: ${{ secrets.POLICYENGINE_UK_DATA_TOKEN }} + HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} DATABASE_URL: ${{ secrets.POLICYENGINE_UK_CHAT_DATABASE_URL }} SUPABASE_URL: ${{ secrets.SUPABASE_URL }} SUPABASE_SERVICE_ROLE_KEY: ${{ secrets.SUPABASE_SERVICE_ROLE_KEY }} @@ -46,6 +47,7 @@ jobs: ANTHROPIC_TITLE_MODEL="claude-haiku-4-5" \ ANTHROPIC_DEFAULT_MODEL="claude-haiku-4-5" \ POLICYENGINE_UK_DATA_TOKEN="$POLICYENGINE_UK_DATA_TOKEN" \ + HUGGING_FACE_TOKEN="$HUGGING_FACE_TOKEN" \ DATABASE_URL="$DATABASE_URL" \ SUPABASE_URL="$SUPABASE_URL" \ SUPABASE_SERVICE_ROLE_KEY="$SUPABASE_SERVICE_ROLE_KEY" \ diff --git a/.github/workflows/pr-beta-deploy.yml b/.github/workflows/pr-beta-deploy.yml index 91386ee..aee81d3 100644 --- a/.github/workflows/pr-beta-deploy.yml +++ b/.github/workflows/pr-beta-deploy.yml @@ -69,6 +69,7 @@ jobs: MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} POLICYENGINE_UK_DATA_TOKEN: ${{ secrets.POLICYENGINE_UK_DATA_TOKEN }} + HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} DATABASE_URL: ${{ secrets.POLICYENGINE_UK_CHAT_DATABASE_URL }} SUPABASE_URL: ${{ secrets.SUPABASE_URL }} SUPABASE_SERVICE_ROLE_KEY: ${{ secrets.SUPABASE_SERVICE_ROLE_KEY }} @@ -82,6 +83,7 @@ jobs: ANTHROPIC_TITLE_MODEL="claude-haiku-4-5" \ ANTHROPIC_DEFAULT_MODEL="claude-haiku-4-5" \ POLICYENGINE_UK_DATA_TOKEN="$POLICYENGINE_UK_DATA_TOKEN" \ + HUGGING_FACE_TOKEN="$HUGGING_FACE_TOKEN" \ DATABASE_URL="$DATABASE_URL" \ SUPABASE_URL="$SUPABASE_URL" \ SUPABASE_SERVICE_ROLE_KEY="$SUPABASE_SERVICE_ROLE_KEY" \ @@ -112,6 +114,7 @@ jobs: MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} POLICYENGINE_UK_DATA_TOKEN: ${{ secrets.POLICYENGINE_UK_DATA_TOKEN }} + HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} DATABASE_URL: ${{ secrets.POLICYENGINE_UK_CHAT_DATABASE_URL }} SUPABASE_URL: ${{ secrets.SUPABASE_URL }} SUPABASE_SERVICE_ROLE_KEY: ${{ secrets.SUPABASE_SERVICE_ROLE_KEY }} @@ -126,6 +129,7 @@ jobs: ANTHROPIC_TITLE_MODEL="claude-haiku-4-5" \ ANTHROPIC_DEFAULT_MODEL="claude-haiku-4-5" \ POLICYENGINE_UK_DATA_TOKEN="$POLICYENGINE_UK_DATA_TOKEN" \ + HUGGING_FACE_TOKEN="$HUGGING_FACE_TOKEN" \ DATABASE_URL="$DATABASE_URL" \ SUPABASE_URL="$SUPABASE_URL" \ SUPABASE_SERVICE_ROLE_KEY="$SUPABASE_SERVICE_ROLE_KEY" \ From ca145957ecdbbf7c7d027c8c25c734d4301530bd Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 13 May 2026 01:55:29 -0700 Subject: [PATCH 5/7] Accept scenario_context from embedders MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an optional scenario_context field to /chat/message and a matching URL query-param read in ChatPage. Embedders (e.g. app-v2 opening the chat in an iframe drawer next to a report) can seed the session with the scenario the user is already viewing. The context lives in its own system block *after* the cache breakpoints, so it never invalidates the cached system prompt or reference doc — same pattern used for plan mode. Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/routes/chatbot.py | 40 ++++++++++++++++++++++++++------ backend/tests/test_api.py | 43 +++++++++++++++++++++++++++++++++++ frontend/src/app/ChatPage.tsx | 11 +++++++++ 3 files changed, 87 insertions(+), 7 deletions(-) diff --git a/backend/routes/chatbot.py b/backend/routes/chatbot.py index 3bc9d93..d091d6f 100644 --- a/backend/routes/chatbot.py +++ b/backend/routes/chatbot.py @@ -157,12 +157,18 @@ def _select_chat_model(messages: List[dict], backend_id: str) -> str: return DEFAULT_FAST_MODEL -def _build_system_blocks(backend_id: str, plan_mode: bool = False) -> List[dict]: - """System prompt + cached library reference + optional plan-mode directive. +def _build_system_blocks( + backend_id: str, + plan_mode: bool = False, + scenario_context: str | None = None, +) -> List[dict]: + """System prompt + cached library reference + optional plan-mode and + scenario-context directives. The system prompt and reference are each marked with cache_control so they - persist across requests. The plan-mode directive is appended AFTER both - cache breakpoints so toggling plan mode never invalidates cached blocks. + persist across requests. The plan-mode and scenario-context blocks are + appended AFTER both cache breakpoints so per-session/per-turn variation + never invalidates the cached blocks. """ blocks: List[dict] = [{ "type": "text", @@ -175,6 +181,11 @@ def _build_system_blocks(backend_id: str, plan_mode: bool = False) -> List[dict] "text": REFERENCE_DOC, "cache_control": {"type": "ephemeral"}, }) + if scenario_context: + blocks.append({ + "type": "text", + "text": f"SCENARIO CONTEXT FROM EMBEDDER:\n{scenario_context}", + }) if plan_mode: blocks.append({"type": "text", "text": PLAN_MODE_DIRECTIVE}) return blocks @@ -195,6 +206,11 @@ class ChatRequest(BaseModel): user_id: str | None = None model_backend: str | None = None plan_mode: bool = False + # Free-form context prepended to the system prompt for the calling + # session — used by embedders (e.g. app-v2) to seed the assistant with + # the scenario the user is already looking at. Lives in its own block + # *after* the cached breakpoints so it never invalidates the prompt cache. + scenario_context: str | None = None PLAN_MODE_DIRECTIVE = """ @@ -287,11 +303,17 @@ async def chat_message(request: ChatRequest, http_request: Request): # The SDK runner takes a single system_prompt string. Plan mode is still # enforced inside it (see runner). Cached reference doc is *not* sent # through this path yet — the SDK runner is opt-in/experimental. + sdk_system_prompt = _build_system_prompt(backend.id) + if request.scenario_context: + sdk_system_prompt += ( + f"\n\nSCENARIO CONTEXT FROM EMBEDDER:\n{request.scenario_context}" + ) + if request.plan_mode: + sdk_system_prompt += "\n\n" + PLAN_MODE_DIRECTIVE return StreamingResponse( generate_claude_agent_sdk_stream( conversation=deduplicated.copy(), - system_prompt=_build_system_prompt(backend.id) - + ("\n\n" + PLAN_MODE_DIRECTIVE if request.plan_mode else ""), + system_prompt=sdk_system_prompt, plan_mode=request.plan_mode, session_id=session_id, user_id=user_id, @@ -317,7 +339,11 @@ async def generate_stream(): model = _select_chat_model(conversation, backend.id) tools = _tool_defs_for_anthropic(backend.id) plan_mode = request.plan_mode - system_blocks = _build_system_blocks(backend.id, plan_mode=plan_mode) + system_blocks = _build_system_blocks( + backend.id, + plan_mode=plan_mode, + scenario_context=request.scenario_context, + ) logger.info( f"[CHAT] Session {session_id}: {len(conversation)} messages, backend={backend.id}" diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index e47ee4a..bdfca2c 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -310,3 +310,46 @@ def test_request_accepts_plan_mode_field(self): assert req.plan_mode is True req2 = ChatRequest(messages=[{"role": "user", "content": "hi"}]) assert req2.plan_mode is False + + +# --------------------------------------------------------------------------- +# Scenario context (embedder integration) +# --------------------------------------------------------------------------- + +class TestScenarioContext: + """Embedders (e.g. app-v2) pass a scenario_context string. It must reach + the system prompt as its own block, after the cache breakpoints, without + invalidating cached blocks.""" + + def test_block_present_when_context_set(self): + from routes.chatbot import _build_system_blocks + blocks = _build_system_blocks( + "uk_compiled", + scenario_context="User is viewing a UK PA-raise report.", + ) + assert any("User is viewing a UK PA-raise report." in b.get("text", "") for b in blocks) + + def test_block_absent_when_context_unset(self): + from routes.chatbot import _build_system_blocks + blocks = _build_system_blocks("uk_compiled") + assert not any("SCENARIO CONTEXT FROM EMBEDDER" in b.get("text", "") for b in blocks) + + def test_context_does_not_invalidate_cache_breakpoints(self): + from routes.chatbot import _build_system_blocks + with_ctx = _build_system_blocks("uk_compiled", scenario_context="hello") + without_ctx = _build_system_blocks("uk_compiled") + # Cached blocks (system prompt, reference doc) come first and must + # be identical so prompt caching stays warm across embedders. + cached_a = [b for b in with_ctx if b.get("cache_control")] + cached_b = [b for b in without_ctx if b.get("cache_control")] + assert cached_a == cached_b + + def test_request_accepts_scenario_context_field(self): + from routes.chatbot import ChatRequest + req = ChatRequest( + messages=[{"role": "user", "content": "hi"}], + scenario_context="some context", + ) + assert req.scenario_context == "some context" + req2 = ChatRequest(messages=[{"role": "user", "content": "hi"}]) + assert req2.scenario_context is None diff --git a/frontend/src/app/ChatPage.tsx b/frontend/src/app/ChatPage.tsx index bcef580..aa4a9ba 100644 --- a/frontend/src/app/ChatPage.tsx +++ b/frontend/src/app/ChatPage.tsx @@ -197,6 +197,10 @@ export default function ChatPage() { const [balance, setBalance] = useState(null); const [topUpLoading, setTopUpLoading] = useState(false); const [planMode, setPlanMode] = useState(false); + // Set once from URL query param (?scenario_context=...) so embedders like + // app-v2 can seed a chat session with the report they're already viewing. + // Read once on mount to avoid invalidating prompt cache mid-conversation. + const [scenarioContext, setScenarioContext] = useState(null); const hasMessages = messages.length > 0; const animatedPlaceholder = useAnimatedPlaceholder(EXAMPLE_QUERIES, !hasMessages && !input); const selectedBackend = modelBackends.find((backend) => backend.id === selectedBackendId); @@ -218,6 +222,12 @@ export default function ChatPage() { finally { setTopUpLoading(false); } }; + useEffect(() => { + if (typeof window === "undefined") return; + const ctx = new URLSearchParams(window.location.search).get("scenario_context"); + if (ctx) setScenarioContext(ctx); + }, []); + useEffect(() => { apiRequest("GET", "chat/backends") .then((data) => { @@ -461,6 +471,7 @@ export default function ChatPage() { user_id: user?.id || null, model_backend: selectedBackendId, plan_mode: planMode, + scenario_context: scenarioContext, }), signal: controller.signal, }); From de36e84781a94500e92b834d18aa28996ae68956 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Wed, 27 May 2026 08:40:45 -0700 Subject: [PATCH 6/7] Ignore .datasets/ and *.tsbuildinfo Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index b543491..7e8c56f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,12 @@ node_modules/ *.egg-info/ dist/ +# Local caches from running policyengine_uk against HF datasets. +.datasets/ + +# TypeScript incremental-build cache. +*.tsbuildinfo + # Generated at image build time by backend/scripts/build_reference.py. # Both backend/Dockerfile and modal_app.py regenerate this against the # installed policyengine-uk-compiled version, so it never lives in git. From 09aa6869af9be8dcecaa2c4c12b408ca36223ce3 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Mon, 1 Jun 2026 11:55:15 -0700 Subject: [PATCH 7/7] Pin policyengine_uk to 2.88.20 to match v1 API and eval fixtures Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/requirements.txt b/backend/requirements.txt index 5950419..b6b2642 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,7 +6,7 @@ anthropic claude-agent-sdk pydantic-ai[anthropic] policyengine-uk-compiled>=0.20.0 -policyengine_uk==2.88.0 +policyengine_uk==2.88.20 pandas httpx supabase