diff --git a/backend/agent_tools.py b/backend/agent_tools.py index e0c6392..e8992eb 100644 --- a/backend/agent_tools.py +++ b/backend/agent_tools.py @@ -1,218 +1,85 @@ """ Agent tools for the microsim chatbot. -Wraps compiled PolicyEngine UK simulations and utility operations. + +This module owns the public LLM-facing tool functions and dispatcher. +Tool schemas live in backend/tool_definitions.py; shared deterministic helpers +live under backend/tooling. """ -import hashlib import json import logging -from pathlib import Path -import sys from typing import Any, Dict, List, Optional -logger = logging.getLogger(__name__) - - -def _ensure_compiled_package_importable() -> None: - """Make the local policyengine_uk_compiled package importable in dev setups.""" - try: - import policyengine_uk_compiled # noqa: F401 - return - except ModuleNotFoundError: - pass - - repo_parent = Path(__file__).resolve().parents[2] - candidates = [ - repo_parent / "policyengine-uk-rust" / "interfaces" / "python", - repo_parent / "policyengine-uk-rust-codex-debug-issue" / "interfaces" / "python", - ] - for candidate in candidates: - if candidate.is_dir(): - candidate_str = str(candidate) - if candidate_str not in sys.path: - sys.path.insert(0, candidate_str) - try: - import policyengine_uk_compiled # noqa: F401 - return - except ModuleNotFoundError: - continue - - raise ModuleNotFoundError( - "policyengine_uk_compiled is not importable. Install the package or make sure a local " - "policyengine-uk-rust checkout with interfaces/python is available." - ) - -# --------------------------------------------------------------------------- -# Microdata cache -# --------------------------------------------------------------------------- -_microdata_cache: Dict[tuple, Any] = {} -_MAX_CACHE = 4 - - -def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): - allowed_roots = {"json", "math", "numpy", "pandas"} - root_name = name.split(".")[0] - if root_name not in allowed_roots: - raise ImportError(f"Import of '{name}' is not allowed") - return __import__(name, globals, locals, fromlist, level) - - -def _json_safe(obj: Any) -> Any: - try: - import numpy as np - except ImportError: - np = None - - try: - import pandas as pd - except ImportError: - pd = None - - if obj is None or isinstance(obj, (str, int, float, bool)): - return obj - if np is not None: - if isinstance(obj, np.ndarray): - return obj.tolist() - if isinstance(obj, np.integer): - return int(obj) - if isinstance(obj, np.floating): - return float(obj) - if isinstance(obj, np.bool_): - return bool(obj) - if pd is not None: - if isinstance(obj, pd.DataFrame): - return obj.to_dict(orient="records") - if isinstance(obj, pd.Series): - return obj.to_list() - if isinstance(obj, dict): - return {str(k): _json_safe(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple, set)): - return [_json_safe(v) for v in obj] - if hasattr(obj, "model_dump") and callable(obj.model_dump): - return _json_safe(obj.model_dump()) - if hasattr(obj, "dict") and callable(obj.dict): - return _json_safe(obj.dict()) - try: - import dataclasses - if dataclasses.is_dataclass(obj): - return _json_safe(dataclasses.asdict(obj)) - except Exception: - pass - return str(obj) - - -def _hash_reform(reform: Optional[Dict[str, Any]]) -> str: - if not reform: - return "none" - return hashlib.md5(json.dumps(reform, sort_keys=True).encode()).hexdigest() +from tool_definitions import TOOL_DEFINITIONS +from tooling.households import build_household_frames +from tooling.microdata import analyse_microdata_result, get_cached_microdata, hash_reform +from tooling.reforms import build_compiled_policy, validate_reform_dict +from tooling.sandbox import ( + build_structural_reform, + compile_structural_hook, + run_generator, + run_python_code, + safe_import, +) +from tooling.serialization import dataframe_to_records, explore_tabular_data, json_safe +from tooling.simulations import DATASET_LABELS, build_simulation, ensure_compiled_package_importable +logger = logging.getLogger(__name__) -def _get_cached_microdata(year: int, reform: Optional[Dict[str, Any]], dataset: str, structural=None): - """Return cached MicrodataResult. Structural reforms always run fresh.""" - if structural is not None: - policy = _build_compiled_policy(reform) - sim = _build_simulation(year, dataset) - return sim.run_microdata(policy=policy, structural=structural) - key = (year, _hash_reform(reform), dataset) - if key not in _microdata_cache: - policy = _build_compiled_policy(reform) - sim = _build_simulation(year, dataset) - _microdata_cache[key] = sim.run_microdata(policy=policy) - if len(_microdata_cache) > _MAX_CACHE: - del _microdata_cache[next(iter(_microdata_cache))] - return _microdata_cache[key] +# Compatibility aliases for tests and existing imports. They remain internal +# unless also listed in TOOL_DEFINITIONS and execute_tool(). +_ensure_compiled_package_importable = ensure_compiled_package_importable +_safe_import = safe_import +_json_safe = json_safe +_hash_reform = hash_reform +_get_cached_microdata = get_cached_microdata +_build_compiled_policy = build_compiled_policy +_build_simulation = build_simulation +_compile_structural_hook = compile_structural_hook +_build_structural_reform = build_structural_reform +_run_generator = run_generator + +__all__ = [ + "TOOL_DEFINITIONS", + "analyse_microdata", + "calculate_household", + "execute_tool", + "explore_tabular_data", + "generate_chart", + "get_baseline_parameters", + "get_capabilities", + "run_economy_simulation", + "run_python", + "validate_reform", +] def get_capabilities() -> Dict[str, Any]: try: _ensure_compiled_package_importable() from policyengine_uk_compiled import capabilities + return capabilities() - except Exception as e: - logger.error(f"Error getting capabilities: {e}") - return {"error": str(e)} - - -def explore_tabular_data(data: List[Dict[str, Any]], max_unique_values: int = 20) -> Dict[str, Any]: - if not data or not isinstance(data[0], dict): - return {"error": "Data must be a non-empty list of dicts", "row_count": 0, "columns": []} - row_count = len(data) - all_keys = set() - for row in data: - all_keys.update(row.keys()) - columns = [] - for key in sorted(all_keys): - values = [row.get(key) for row in data] - sample_type = next((type(v).__name__ for v in values if v is not None), "unknown") - unique_values = list(set(v for v in values if v is not None)) - unique_count = len(unique_values) - col_info = {"name": key, "type": sample_type, "unique_count": unique_count, "null_count": sum(1 for v in values if v is None)} - if unique_count <= max_unique_values: - try: - col_info["unique_values"] = sorted(unique_values) - except TypeError: - col_info["unique_values"] = unique_values - if sample_type in ("int", "float"): - numeric = [v for v in values if isinstance(v, (int, float))] - if numeric: - col_info["min"] = min(numeric) - col_info["max"] = max(numeric) - columns.append(col_info) - return {"row_count": row_count, "columns": columns} - - -def _build_compiled_policy(reform: Optional[Dict[str, Any]]): - if not reform: - return None - _ensure_compiled_package_importable() - from policyengine_uk_compiled import ( - Parameters, IncomeTaxParams, NationalInsuranceParams, UniversalCreditParams, - ChildBenefitParams, StatePensionParams, PensionCreditParams, BenefitCapParams, - HousingBenefitParams, TaxCreditsParams, ScottishChildPaymentParams, - StampDutyParams, StampDutyBand, CapitalGainsTaxParams, WealthTaxParams, - ) - param_cls_map = { - "income_tax": IncomeTaxParams, - "national_insurance": NationalInsuranceParams, - "universal_credit": UniversalCreditParams, - "child_benefit": ChildBenefitParams, - "state_pension": StatePensionParams, - "pension_credit": PensionCreditParams, - "benefit_cap": BenefitCapParams, - "housing_benefit": HousingBenefitParams, - "tax_credits": TaxCreditsParams, - "scottish_child_payment": ScottishChildPaymentParams, - "stamp_duty": StampDutyParams, - "capital_gains_tax": CapitalGainsTaxParams, - "wealth_tax": WealthTaxParams, - } - kwargs = {} - for program, fields in reform.items(): - if program not in param_cls_map: - raise ValueError(f"Unknown reform program '{program}'. Valid: {list(param_cls_map)}") - if not isinstance(fields, dict): - raise ValueError(f"Reform program '{program}' must be a dict, got {type(fields).__name__}") - cls = param_cls_map[program] - # stamp_duty bands is a list of dicts — convert to StampDutyBand objects - if cls is StampDutyParams and "bands" in fields and fields["bands"] is not None: - fields = {**fields, "bands": [StampDutyBand(**b) if isinstance(b, dict) else b for b in fields["bands"]]} - valid_fields = set(cls.model_fields) - unknown = {k for k in fields if k not in valid_fields and fields[k] is not None} - if unknown: - raise ValueError(f"Unknown field(s) for '{program}': {sorted(unknown)}. Valid: {sorted(valid_fields)}") - kwargs[program] = cls(**{k: v for k, v in fields.items() if v is not None}) - return Parameters(**kwargs) if kwargs else None + except Exception as exc: + logger.error(f"Error getting capabilities: {exc}") + return {"error": str(exc)} def get_baseline_parameters(year: int = 2025) -> Dict[str, Any]: try: _ensure_compiled_package_importable() from policyengine_uk_compiled import Simulation + sim = Simulation(year=year) return {"year": year, "parameters": sim.get_baseline_params()} - except Exception as e: - logger.error(f"Error getting baseline parameters: {e}") - return {"error": str(e)} + except Exception as exc: + logger.error(f"Error getting baseline parameters: {exc}") + return {"error": str(exc)} + + +def validate_reform(reform: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Validate parametric reform JSON without running a simulation.""" + return validate_reform_dict(reform) def calculate_household( @@ -223,152 +90,28 @@ def calculate_household( reform: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: try: - import pandas as pd - from policyengine_uk_compiled import Simulation, PERSON_DEFAULTS, BENUNIT_DEFAULTS, HOUSEHOLD_DEFAULTS - - def fill_defaults(records, defaults): - return pd.DataFrame([{**defaults, **rec} for rec in records]) - - # Remap IDs to 0-based (the compiled engine uses IDs as array indices) - hh_id_map = {rec["household_id"]: i for i, rec in enumerate(household)} - bu_id_map = {rec["benunit_id"]: i for i, rec in enumerate(benunit)} - person = [ - {**rec, "person_id": i, "benunit_id": bu_id_map[rec["benunit_id"]], "household_id": hh_id_map[rec["household_id"]]} - for i, rec in enumerate(person) - ] - benunit = [ - {**rec, "benunit_id": bu_id_map[rec["benunit_id"]], "household_id": hh_id_map[rec["household_id"]]} - for rec in benunit - ] - household = [ - {**rec, "household_id": hh_id_map[rec["household_id"]]} - for rec in household - ] - - # Set is_benunit_head/is_household_head: first adult (age>=16) per unit is head - seen_bu_heads = set() - seen_hh_heads = set() - for rec in person: - bu_id = rec["benunit_id"] - hh_id = rec["household_id"] - is_adult = rec.get("age", 30) >= 16 - rec["is_benunit_head"] = is_adult and bu_id not in seen_bu_heads - rec["is_household_head"] = is_adult and hh_id not in seen_hh_heads - if rec["is_benunit_head"]: - seen_bu_heads.add(bu_id) - if rec["is_household_head"]: - seen_hh_heads.add(hh_id) - - persons_df = fill_defaults(person, PERSON_DEFAULTS) - benunits_df = fill_defaults(benunit, BENUNIT_DEFAULTS) - households_df = fill_defaults(household, HOUSEHOLD_DEFAULTS) - - if "person_ids" not in benunits_df.columns or (benunits_df["person_ids"] == BENUNIT_DEFAULTS.get("person_ids", 0)).all(): - # Build comma-separated person_ids for each benunit from persons_df - bu_to_persons = persons_df.groupby("benunit_id")["person_id"].apply(lambda ids: ",".join(str(i) for i in ids)) - benunits_df["person_ids"] = benunits_df["benunit_id"].map(bu_to_persons).fillna(benunits_df["benunit_id"].astype(str)) - if "benunit_ids" not in households_df.columns or (households_df["benunit_ids"] == HOUSEHOLD_DEFAULTS.get("benunit_ids", 0)).all(): - hh_to_benunits = benunits_df.groupby("household_id")["benunit_id"].apply(lambda ids: ",".join(str(i) for i in ids)) - households_df["benunit_ids"] = households_df["household_id"].map(hh_to_benunits).fillna(households_df["household_id"].astype(str)) - if "person_ids" not in households_df.columns or (households_df["person_ids"] == HOUSEHOLD_DEFAULTS.get("person_ids", 0)).all(): - hh_to_persons = persons_df.groupby("household_id")["person_id"].apply(lambda ids: ",".join(str(i) for i in ids)) - households_df["person_ids"] = households_df["household_id"].map(hh_to_persons).fillna(households_df["household_id"].astype(str)) + _ensure_compiled_package_importable() + from policyengine_uk_compiled import Simulation + persons_df, benunits_df, households_df = build_household_frames(person, benunit, household) sim = Simulation(year=year, persons=persons_df, benunits=benunits_df, households=households_df) policy = _build_compiled_policy(reform) result = sim.run_microdata(policy=policy) - def df_to_records(df): - return [{k: (None if (hasattr(v, '__class__') and v.__class__.__name__ == 'float' and str(v) == 'nan') else v) for k, v in row.items()} for row in df.to_dict(orient="records")] - return { "status": "success", "year": year, "reform_applied": reform is not None, - "person": df_to_records(result.persons), - "benunit": df_to_records(result.benunits), - "household": df_to_records(result.households), + "person": dataframe_to_records(result.persons), + "benunit": dataframe_to_records(result.benunits), + "household": dataframe_to_records(result.households), } - except Exception as e: - logger.error(f"Error in calculate_household: {e}") - import traceback; logger.error(traceback.format_exc()) - return {"error": str(e)} -def _build_simulation(year: int, dataset: str = "frs"): - """Build a Simulation with the right data source and CLI flags.""" - _ensure_compiled_package_importable() - from policyengine_uk_compiled import Simulation - return Simulation(year=year, dataset=dataset) - - -def _compile_structural_hook(code: str): - """Compile a structural hook from code defining `hook(...)`. - - The hook signature must be: - hook(year, persons, benunits, households) -> (persons, benunits, households) - """ - import math - import builtins as _builtins - - safe_names = ( - "range", "len", "int", "float", "str", "bool", "list", "dict", - "tuple", "set", "zip", "enumerate", "map", "filter", "sorted", - "reversed", "min", "max", "sum", "abs", "round", "True", "False", - "None", "isinstance", "ValueError", "TypeError", "print", - "any", "all", "pow", "divmod", - ) - safe_builtins = {k: getattr(_builtins, k) for k in safe_names if hasattr(_builtins, k)} + except Exception as exc: + logger.error(f"Error in calculate_household: {exc}") + import traceback - try: - import numpy as np - except ImportError: - np = None - - try: - import pandas as pd - except ImportError as e: - raise ImportError("pandas is required for structural reform hooks") from e - - allowed_globals: Dict[str, Any] = { - "__builtins__": safe_builtins, - "math": math, - "json": json, - "pd": pd, - } - if np is not None: - allowed_globals["np"] = np - allowed_globals["numpy"] = np - - exec(code, allowed_globals) - hook = allowed_globals.get("hook") - if hook is None or not callable(hook): - raise ValueError("Structural hook code must define a callable `hook(year, persons, benunits, households)`") - return hook - - -def _build_structural_reform(structural_reform: Optional[Dict[str, Any]]): - if not structural_reform: - return None - if not isinstance(structural_reform, dict): - raise ValueError(f"structural_reform must be a dict, got {type(structural_reform).__name__}") - - unknown = set(structural_reform) - {"pre", "post"} - if unknown: - raise ValueError(f"Unknown structural_reform field(s): {sorted(unknown)}. Valid: ['pre', 'post']") - - _ensure_compiled_package_importable() - from policyengine_uk_compiled import StructuralReform - - pre = structural_reform.get("pre") - post = structural_reform.get("post") - if pre is not None and not isinstance(pre, str): - raise ValueError("structural_reform.pre must be a string of Python code defining hook(...)") - if post is not None and not isinstance(post, str): - raise ValueError("structural_reform.post must be a string of Python code defining hook(...)") - - return StructuralReform( - pre=_compile_structural_hook(pre) if pre else None, - post=_compile_structural_hook(post) if post else None, - ) + logger.error(traceback.format_exc()) + return {"error": str(exc)} def run_economy_simulation( @@ -381,10 +124,10 @@ def run_economy_simulation( policy = _build_compiled_policy(reform) structural = _build_structural_reform(structural_reform) sim = _build_simulation(year, dataset) - # Always run baseline to compute program-level changes baseline_result = sim.run() if structural is not None: from policyengine_uk_compiled import aggregate_microdata, combine_microdata + baseline_microdata = sim.run_microdata() reform_microdata = sim.run_microdata(policy=policy, structural=structural) combined_microdata = combine_microdata(baseline_microdata, reform_microdata) @@ -400,14 +143,17 @@ def run_economy_simulation( baseline_breakdown = baseline_result.program_breakdown.model_dump() reform_breakdown = reform_result.program_breakdown.model_dump() program_changes = { - k: {"baseline": baseline_breakdown[k], "reform": reform_breakdown[k], "change": reform_breakdown[k] - baseline_breakdown[k]} - for k in baseline_breakdown + key: { + "baseline": baseline_breakdown[key], + "reform": reform_breakdown[key], + "change": reform_breakdown[key] - baseline_breakdown[key], + } + for key in baseline_breakdown } - dataset_labels = {"frs": "Family Resources Survey", "efrs": "Enhanced FRS", "spi": "Survey of Personal Incomes", "lcfs": "Living Costs and Food Survey", "was": "Wealth and Assets Survey"} return { "fiscal_year": reform_result.fiscal_year, - "dataset": dataset_labels.get(dataset, dataset), + "dataset": DATASET_LABELS.get(dataset, dataset), "budgetary_impact": reform_result.budgetary_impact.model_dump(), "program_breakdown_changes": program_changes, "decile_impacts": [d.model_dump() for d in reform_result.decile_impacts], @@ -419,12 +165,18 @@ def run_economy_simulation( "reform_poverty": reform_result.reform_poverty.model_dump(), "structural_reform_applied": structural is not None, } - except FileNotFoundError as e: - return {"error": f"{dataset.upper()} microdata not available", "detail": str(e), "hint": "Ensure POLICYENGINE_UK_DATA_TOKEN is set."} - except Exception as e: - logger.error(f"Error in run_economy_simulation: {e}") - import traceback; logger.error(traceback.format_exc()) - return {"error": str(e)} + except FileNotFoundError as exc: + return { + "error": f"{dataset.upper()} microdata not available", + "detail": str(exc), + "hint": "Ensure POLICYENGINE_UK_DATA_TOKEN is set.", + } + except Exception as exc: + logger.error(f"Error in run_economy_simulation: {exc}") + import traceback + + logger.error(traceback.format_exc()) + return {"error": str(exc)} def analyse_microdata( @@ -437,362 +189,154 @@ def analyse_microdata( columns: Optional[List[str]] = None, group_by: Optional[List[str]] = None, n: int = 5, - dataset: str = "frs", + dataset: str = "efrs", ) -> Dict[str, Any]: try: - import pandas as pd + dataset_key = (dataset or "").lower() + if dataset_key == "frs": + return { + "error": "analyse_microdata does not support FRS row-level access", + "hint": ( + "Use run_economy_simulation for aggregate FRS outputs, or choose " + "a non-FRS dataset for analyse_microdata." + ), + } policy = _build_compiled_policy(reform) structural = _build_structural_reform(structural_reform) if structural is not None: from policyengine_uk_compiled import combine_microdata - sim = _build_simulation(year, dataset) + + sim = _build_simulation(year, dataset_key) baseline_microdata = sim.run_microdata() reform_microdata = sim.run_microdata(policy=policy, structural=structural) microdata = combine_microdata(baseline_microdata, reform_microdata) else: - microdata = _get_cached_microdata(year, reform, dataset) - - entity_map = {"persons": microdata.persons, "benunits": microdata.benunits, "households": microdata.households} - if entity not in entity_map: - return {"error": f"entity must be one of: persons, benunits, households"} - df = entity_map[entity].copy() - - weights = microdata.households[["household_id", "weight"]].copy() - if "household_id" in df.columns and "weight" not in df.columns: - df = df.merge(weights, on="household_id", how="left") - elif "weight" not in df.columns: - df["weight"] = 1.0 - - change_pairs = { - "persons": [("income_tax", "baseline_income_tax", "reform_income_tax"), ("employee_ni", "baseline_employee_ni", "reform_employee_ni"), ("total_income", "baseline_total_income", "reform_total_income")], - "benunits": [("total_benefits", "baseline_total_benefits", "reform_total_benefits"), ("universal_credit", "baseline_universal_credit", "reform_universal_credit"), ("child_benefit", "baseline_child_benefit", "reform_child_benefit")], - "households": [("net_income", "baseline_net_income", "reform_net_income"), ("total_tax", "baseline_total_tax", "reform_total_tax"), ("total_benefits", "baseline_total_benefits", "reform_total_benefits")], - } - for change_col, base_col, ref_col in change_pairs.get(entity, []): - if base_col in df.columns and ref_col in df.columns: - df[f"{change_col}_change"] = df[ref_col] - df[base_col] - - filters_applied = {} - if filters: - for col, fval in filters.items(): - if col not in df.columns: - return {"error": f"Column '{col}' not found. Available: {list(df.columns)}"} - filters_applied[col] = fval - if isinstance(fval, dict): - if "min" in fval: df = df[df[col] >= fval["min"]] - if "max" in fval: df = df[df[col] <= fval["max"]] - if "gt" in fval: df = df[df[col] > fval["gt"]] - if "lt" in fval: df = df[df[col] < fval["lt"]] - if "gte" in fval: df = df[df[col] >= fval["gte"]] - if "lte" in fval: df = df[df[col] <= fval["lte"]] - if "ne" in fval: df = df[df[col] != fval["ne"]] - elif isinstance(fval, list): - df = df[df[col].isin(fval)] - else: - df = df[df[col] == fval] - - row_count = len(df) - weighted_count = int(df["weight"].sum()) if "weight" in df.columns else row_count - all_cols = list(df.columns) - - if columns: - missing = [c for c in columns if c not in df.columns] - if missing: - return {"error": f"Columns not found: {missing}. Available: {all_cols}"} - value_cols = columns - else: - if entity == "persons": - value_cols = ["age", "gender", "employment_income", "self_employment_income", "baseline_income_tax", "reform_income_tax", "income_tax_change", "baseline_total_income", "reform_total_income", "total_income_change"] - elif entity == "benunits": - value_cols = ["baseline_total_benefits", "reform_total_benefits", "total_benefits_change", "baseline_universal_credit", "reform_universal_credit", "baseline_child_benefit", "reform_child_benefit"] - else: - value_cols = ["region", "baseline_net_income", "reform_net_income", "net_income_change", "baseline_total_tax", "reform_total_tax", "baseline_total_benefits", "reform_total_benefits"] - value_cols = [c for c in value_cols if c in df.columns] - - if operation == "sample": - actual_n = min(n, 20, row_count) - sample_df = df[value_cols].sample(n=actual_n, random_state=42) if row_count >= actual_n else df[value_cols] - result = [{k: (None if (isinstance(v, float) and str(v) == "nan") else v) for k, v in row.items()} for row in sample_df.to_dict(orient="records")] - elif operation == "mean": - numeric_cols = [c for c in value_cols if pd.api.types.is_numeric_dtype(df[c]) and c != "weight"] - result = {c: float((df[c] * df["weight"]).sum() / df["weight"].sum()) if df["weight"].sum() > 0 else float(df[c].mean()) for c in numeric_cols} - elif operation == "sum": - numeric_cols = [c for c in value_cols if pd.api.types.is_numeric_dtype(df[c]) and c != "weight"] - result = {c: float((df[c] * df["weight"]).sum()) for c in numeric_cols} - elif operation == "count": - result = {"row_count": row_count, "weighted_population": weighted_count} - elif operation == "describe": - numeric_cols = [c for c in value_cols if pd.api.types.is_numeric_dtype(df[c]) and c != "weight"] - result = {c: {"mean": float((df[c] * df["weight"]).sum() / df["weight"].sum()) if df["weight"].sum() > 0 else float(df[c].mean()), "min": float(df[c].min()), "max": float(df[c].max()), "count": int(df[c].count())} for c in numeric_cols} - for c in [c for c in value_cols if not pd.api.types.is_numeric_dtype(df[c])]: - result[c] = {str(k): int(v) for k, v in df[c].value_counts().head(10).items()} - else: - return {"error": f"Unknown operation '{operation}'. Use: mean, sum, count, sample, describe"} - - dataset_labels = {"frs": "Family Resources Survey", "efrs": "Enhanced FRS", "spi": "Survey of Personal Incomes", "lcfs": "Living Costs and Food Survey", "was": "Wealth and Assets Survey"} - return {"entity": entity, "operation": operation, "year": year, "dataset": dataset_labels.get(dataset, dataset), "reform_applied": reform is not None, "structural_reform_applied": structural is not None, "filters_applied": filters_applied, "row_count": row_count, "weighted_count": weighted_count, "result": result, "available_columns": all_cols} - except Exception as e: - logger.error(f"Error in analyse_microdata: {e}") - import traceback; logger.error(traceback.format_exc()) - return {"error": str(e)} + microdata = _get_cached_microdata(year, reform, dataset_key) + + return analyse_microdata_result( + microdata=microdata, + entity=entity, + operation=operation, + year=year, + dataset_key=dataset_key, + reform_applied=reform is not None, + structural_reform_applied=structural is not None, + filters=filters, + columns=columns, + group_by=group_by, + n=n, + ) + except Exception as exc: + logger.error(f"Error in analyse_microdata: {exc}") + import traceback + + logger.error(traceback.format_exc()) + return {"error": str(exc)} def generate_chart( - chart_type: str, title: str, data: List[Dict[str, Any]], x_field: str, y_fields: List[str], - x_label: Optional[str] = None, y_label: Optional[str] = None, - x_format: Optional[str] = None, y_format: Optional[str] = None, - x_min: Optional[float] = None, x_max: Optional[float] = None, - y_min: Optional[float] = None, y_max: Optional[float] = None, - series_labels: Optional[List[str]] = None, series_styles: Optional[List[str]] = None, - series_curves: Optional[List[str]] = None, subtitle: Optional[str] = None, - source: Optional[str] = None, arrangement: Optional[str] = None, area_fill: Optional[bool] = None, + chart_type: str, + title: str, + data: List[Dict[str, Any]], + x_field: str, + y_fields: List[str], + x_label: Optional[str] = None, + y_label: Optional[str] = None, + x_format: Optional[str] = None, + y_format: Optional[str] = None, + x_min: Optional[float] = None, + x_max: Optional[float] = None, + y_min: Optional[float] = None, + y_max: Optional[float] = None, + series_labels: Optional[List[str]] = None, + series_styles: Optional[List[str]] = None, + series_curves: Optional[List[str]] = None, + subtitle: Optional[str] = None, + source: Optional[str] = None, + arrangement: Optional[str] = None, + area_fill: Optional[bool] = None, ) -> Dict[str, Any]: try: series = [] - for i, y_field in enumerate(y_fields): - s = {"field": y_field, "label": series_labels[i] if series_labels and i < len(series_labels) else y_field} - if series_styles and i < len(series_styles): s["lineStyle"] = series_styles[i] - if series_curves and i < len(series_curves): s["curve"] = series_curves[i] - series.append(s) + for index, y_field in enumerate(y_fields): + item = {"field": y_field, "label": series_labels[index] if series_labels and index < len(series_labels) else y_field} + if series_styles and index < len(series_styles): + item["lineStyle"] = series_styles[index] + if series_curves and index < len(series_curves): + item["curve"] = series_curves[index] + series.append(item) spec = { - "type": chart_type, "title": title, + "type": chart_type, + "title": title, "x": {"field": x_field, "label": x_label or x_field}, - "y": {"field": y_fields[0] if len(y_fields) == 1 else "value", "label": y_label or (y_fields[0] if len(y_fields) == 1 else "Value")}, - "series": series, "data": data, "showLegend": len(y_fields) > 1, "showGrid": True, + "y": { + "field": y_fields[0] if len(y_fields) == 1 else "value", + "label": y_label or (y_fields[0] if len(y_fields) == 1 else "Value"), + }, + "series": series, + "data": data, + "showLegend": len(y_fields) > 1, + "showGrid": True, + } + if x_format: + spec["x"]["format"] = x_format + if y_format: + spec["y"]["format"] = y_format + if x_min is not None: + spec["x"]["min"] = x_min + if x_max is not None: + spec["x"]["max"] = x_max + if y_min is not None: + spec["y"]["min"] = y_min + if y_max is not None: + spec["y"]["max"] = y_max + if subtitle: + spec["subtitle"] = subtitle + if source: + spec["source"] = source + if arrangement and chart_type == "bar": + spec["arrangement"] = arrangement + if area_fill and chart_type == "line": + spec["areaFill"] = area_fill + + return { + "status": "success", + "chart_markdown": f"```chart\n{json.dumps(spec, indent=2)}\n```", + "message": "Chart generated. Include the chart_markdown in your response to display it.", } - if x_format: spec["x"]["format"] = x_format - if y_format: spec["y"]["format"] = y_format - if x_min is not None: spec["x"]["min"] = x_min - if x_max is not None: spec["x"]["max"] = x_max - if y_min is not None: spec["y"]["min"] = y_min - if y_max is not None: spec["y"]["max"] = y_max - if subtitle: spec["subtitle"] = subtitle - if source: spec["source"] = source - if arrangement and chart_type == "bar": spec["arrangement"] = arrangement - if area_fill and chart_type == "line": spec["areaFill"] = area_fill - - return {"status": "success", "chart_markdown": f"```chart\n{json.dumps(spec, indent=2)}\n```", "message": "Chart generated. Include the chart_markdown in your response to display it."} - except Exception as e: - return {"error": str(e)} + except Exception as exc: + return {"error": str(exc)} def run_python(code: str) -> Dict[str, Any]: - """Execute Python code with the PolicyEngine UK compiled interface preloaded. - - The code should assign its final result to a variable called `result`. - The environment includes the official Python wrapper so runs are easy to - reproduce outside the chat app. - """ - import math - import builtins as _builtins - _ensure_compiled_package_importable() - import pandas as pd - import policyengine_uk_compiled as pe - - from policyengine_uk_compiled import ( - Simulation, - StructuralReform, - Parameters, - aggregate_microdata, - combine_microdata, - capabilities, - ensure_dataset, - ) - - safe_names = ( - "range", "len", "int", "float", "str", "bool", "list", "dict", - "tuple", "set", "zip", "enumerate", "map", "filter", "sorted", - "reversed", "min", "max", "sum", "abs", "round", "True", "False", - "None", "isinstance", "ValueError", "TypeError", "Exception", - "print", "any", "all", "pow", "divmod", "complex", "type", - "dir", "hasattr", "getattr", - ) - safe_builtins = {k: getattr(_builtins, k) for k in safe_names if hasattr(_builtins, k)} + """Execute Python code with the PolicyEngine UK compiled interface preloaded.""" + return run_python_code(code) - try: - import numpy as np - except ImportError: - np = None - - output_lines: List[str] = [] - def safe_print(*args, **kwargs): - output_lines.append(" ".join(str(a) for a in args)) - - safe_builtins["print"] = safe_print - safe_builtins["__import__"] = _safe_import - - allowed_globals: Dict[str, Any] = { - "__builtins__": safe_builtins, - "math": math, - "json": json, - "pd": pd, - "pe": pe, - "Simulation": Simulation, - "StructuralReform": StructuralReform, - "Parameters": Parameters, - "aggregate_microdata": aggregate_microdata, - "combine_microdata": combine_microdata, - "capabilities": capabilities, - "ensure_dataset": ensure_dataset, - } - if np is not None: - allowed_globals["np"] = np - allowed_globals["numpy"] = np - try: - exec(code, allowed_globals) - except Exception as e: - return {"error": f"{type(e).__name__}: {e}"} - - result = allowed_globals.get("result", None) - - response: Dict[str, Any] = {} - if result is not None: - response["result"] = _json_safe(result) - if output_lines: - response["output"] = "\n".join(output_lines) - if not response: - response["result"] = None - response["note"] = "No 'result' variable was set and nothing was printed." - return response - - - -def _run_generator(code: str) -> Dict[str, Any]: - """Execute a Python generator snippet that returns a dict of tool kwargs. - - The code must define a `generate()` function that returns a dict. - Only safe builtins + math are available — no file/network/import access. - """ - import math - import builtins as _builtins - safe_names = ( - "range", "len", "int", "float", "str", "bool", "list", "dict", - "tuple", "set", "zip", "enumerate", "map", "filter", "sorted", - "reversed", "min", "max", "sum", "abs", "round", "True", "False", - "None", "isinstance", "ValueError", "TypeError", "append", - ) - safe_builtins = {k: getattr(_builtins, k) for k in safe_names if hasattr(_builtins, k)} - allowed_globals: Dict[str, Any] = {"__builtins__": safe_builtins, "math": math, "json": json} - exec(code, allowed_globals) - if "generate" not in allowed_globals: - raise ValueError("Generator code must define a `generate()` function") - result = allowed_globals["generate"]() - if not isinstance(result, dict): - raise ValueError(f"generate() must return a dict, got {type(result).__name__}") - return result +TOOL_HANDLERS = { + "validate_reform": validate_reform, + "calculate_household": calculate_household, + "run_economy_simulation": run_economy_simulation, + "analyse_microdata": analyse_microdata, + "run_python": run_python, + "generate_chart": generate_chart, +} def execute_tool(tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]: logger.info(f"[TOOLS] Executing {tool_name}") - tools = { - "run_python": run_python, - "generate_chart": generate_chart, - } - if tool_name not in tools: + if tool_name not in TOOL_HANDLERS: return {"error": f"Unknown tool: {tool_name}"} try: - # If input contains a generator, execute it to produce the real kwargs if "generator" in tool_input: logger.info(f"[TOOLS] Running generator for {tool_name}") tool_input = _run_generator(tool_input["generator"]) logger.info(f"[TOOLS] Generator produced keys: {list(tool_input.keys())}") - result = tools[tool_name](**tool_input) + result = TOOL_HANDLERS[tool_name](**tool_input) logger.info(f"[TOOLS] {tool_name} completed") return result - except Exception as e: - logger.error(f"[TOOLS] Error in {tool_name}: {e}") - return {"error": str(e)} - - -TOOL_DEFINITIONS = [ - { - "name": "run_python", - "description": "Execute reproducible Python code using the official PolicyEngine UK compiled interface. The environment preloads `policyengine_uk_compiled` as `pe`, plus `Simulation`, `Parameters`, `StructuralReform`, `aggregate_microdata`, `combine_microdata`, `capabilities`, `ensure_dataset`, `pd`, `np`, `json`, and `math`. Assign the final answer to `result` and use `print()` for intermediate output. Do not inspect or return row-level survey microdata. For household examples, create illustrative synthetic households, prefer `Simulation.single_person()` for single-person examples, and label them as illustrative rather than real households.", - "input_schema": { - "type": "object", - "properties": { - "code": {"type": "string", "description": "Python code to execute. Must assign the final answer to `result`. Use the preloaded PolicyEngine interface directly, for example: `sim = Simulation(year=2025)` or `policy = Parameters.model_validate({...})`."}, - }, - "required": ["code"], - }, - }, - { - "name": "generate_chart", - "description": ( - "Generate a chart JSON block for the frontend to render. " - "Use this for visualisations such as income distributions, marginal-rate or tax-schedule curves, " - "decile impact comparisons, and trends over time or income. " - "Use factually neutral titles, subtitles, labels, and captions; do not call policies good, bad, fair, unfair, " - "regressive, progressive, generous, or punitive. " - "The tool returns a `chart_markdown` field containing a ```chart fenced JSON block — you MUST paste that " - "string verbatim into your next text response, otherwise the chart will not appear to the user. " - "Do not attempt to render charts with matplotlib inside `run_python`; the UI cannot display matplotlib output. " - "Compute the data first with `run_python` (returning a list of row dicts), then pass it to this tool." - ), - "input_schema": { - "type": "object", - "properties": { - "chart_type": { - "type": "string", - "enum": ["line", "bar", "area", "scatter"], - "description": "Chart type. Use `line` for schedules/curves over a continuous x, `bar` for category comparisons (e.g. deciles), `area` for stacked compositions, `scatter` for point clouds.", - }, - "title": {"type": "string", "description": "Factually neutral chart title shown above the plot."}, - "data": { - "type": "array", - "description": "List of row objects. Each row must contain the `x_field` key and every key listed in `y_fields`.", - "items": {"type": "object"}, - }, - "x_field": {"type": "string", "description": "Key in each data row to use as the x value."}, - "y_fields": { - "type": "array", - "description": "Keys in each data row to plot as y series. Provide multiple for multi-series charts (e.g. baseline vs reform).", - "items": {"type": "string"}, - }, - "x_label": {"type": "string", "description": "Axis label for x (defaults to `x_field`)."}, - "y_label": {"type": "string", "description": "Axis label for y (defaults to first y field or 'Value')."}, - "x_format": { - "type": "string", - "enum": ["currency", "percent", "percent_decimal", "number", "compact", "year"], - "description": "Number format for x-axis ticks and tooltips. Use `currency` for £ amounts, `percent` for values already on a 0–100 scale, `percent_decimal` for 0–1 shares, `compact` for large counts (1.2k), `year` for calendar years.", - }, - "y_format": { - "type": "string", - "enum": ["currency", "percent", "percent_decimal", "number", "compact", "year"], - "description": "Number format for y-axis ticks and tooltips. Same options as `x_format`.", - }, - "x_min": {"type": "number", "description": "Optional fixed minimum for the x axis."}, - "x_max": {"type": "number", "description": "Optional fixed maximum for the x axis."}, - "y_min": {"type": "number", "description": "Optional fixed minimum for the y axis."}, - "y_max": {"type": "number", "description": "Optional fixed maximum for the y axis."}, - "series_labels": { - "type": "array", - "description": "Display labels for each y series, in the same order as `y_fields`.", - "items": {"type": "string"}, - }, - "series_styles": { - "type": "array", - "description": "Line style per series (line/area charts).", - "items": {"type": "string", "enum": ["solid", "dashed", "dotted"]}, - }, - "series_curves": { - "type": "array", - "description": "Curve interpolation per series (line/area charts).", - "items": {"type": "string", "enum": ["smooth", "step", "linear"]}, - }, - "subtitle": {"type": "string", "description": "Optional subtitle shown under the title."}, - "source": {"type": "string", "description": "Optional source/caption shown beneath the chart."}, - "arrangement": { - "type": "string", - "enum": ["grouped", "stacked"], - "description": "For bar charts only: `grouped` side-by-side or `stacked`.", - }, - "area_fill": {"type": "boolean", "description": "For line charts only: fill the area under the line."}, - }, - "required": ["chart_type", "title", "data", "x_field", "y_fields"], - }, - }, -] + except Exception as exc: + logger.error(f"[TOOLS] Error in {tool_name}: {exc}") + return {"error": str(exc)} diff --git a/backend/prompts.py b/backend/prompts.py index 0833af0..468ebfc 100644 --- a/backend/prompts.py +++ b/backend/prompts.py @@ -14,19 +14,26 @@ """ PYTHON_COMPUTATION_RULES = """ -CRITICAL - ALWAYS COMPUTE WITH PYTHON: +CRITICAL - ALWAYS COMPUTE WITH TOOLS: - Never answer quantitative policy questions from memory. -- You have one execution tool: `run_python`. -- Use `run_python` for every tax, benefit, reform, schedule, poverty, decile, - and distributional question. -- Every number in your answer must come directly from the Python result you - just computed. +- Every number in your answer must come directly from a tool result you just + computed. +- Prefer the typed calculation tools when the question fits their shape: + `calculate_household` for illustrative household-level questions, + `run_economy_simulation` for society-wide reform analysis, and + `analyse_microdata` for allowed non-FRS microdata analysis. +- Use `validate_reform` when the user is drafting, debugging, or asking + whether parametric reform JSON is valid. Do not call it as a routine + preflight before every simulation; calculation tools validate internally. +- Use `run_python` as the fallback for structural reforms, parameter + introspection, historical lookups, novel aggregations, or cases the typed + tools cannot express. """ MODEL_INSTRUCTIONS_RULES = """ CRITICAL - START BY READING THE MODEL INSTRUCTIONS: -- At the start of a new line of analysis, use Python to inspect - `capabilities()`. +- When using `run_python` at the start of a new line of analysis, inspect + `capabilities()` first. - Use that to ground yourself in the available datasets, years, programmes, and caveats before you simulate. - If the user asks about something outside the modelled scope, say so clearly @@ -65,6 +72,10 @@ or real households. - Use aggregate microdata interfaces only for aggregate outputs; do not inspect or return individual survey rows as examples. +- `analyse_microdata` must not be used with FRS. For FRS, use aggregate outputs + such as `run_economy_simulation`. +- If `analyse_microdata` returns non-FRS sample records, describe them as + model records, not real households or actual survey rows. - If the user asks how individual households are constructed in the data, what households in the data look like, or for examples of actual household records, explain that this app cannot access or disclose real households. @@ -123,7 +134,7 @@ CHARTS: - When a visualisation would help (distributions, marginal-rate or tax-schedule curves, decile comparisons, trends), call the `generate_chart` tool after you - have the data from `run_python`. + have the data from a typed calculation tool or `run_python`. - The tool returns a `chart_markdown` field containing a ```chart fenced JSON block. Paste that block VERBATIM into your next text response - the frontend parses it to render the chart. If you do not include it, no chart will diff --git a/backend/routes/chatbot.py b/backend/routes/chatbot.py index 47ab22b..900394f 100644 --- a/backend/routes/chatbot.py +++ b/backend/routes/chatbot.py @@ -54,6 +54,7 @@ SUGGESTION_MODEL = os.environ.get("ANTHROPIC_SUGGESTION_MODEL", DEFAULT_FAST_MODEL) SUGGESTION_TIMEOUT_SECS = float(os.environ.get("ANTHROPIC_SUGGESTION_TIMEOUT_SECS", "5")) FAST_MODEL_MAX_INPUT_TOKENS = int(os.environ.get("ANTHROPIC_FAST_MODEL_MAX_INPUT_TOKENS", "120000")) +CHAT_TEMPERATURE = float(os.environ.get("ANTHROPIC_CHAT_TEMPERATURE", "0")) _REFERENCE_PATH = Path(__file__).resolve().parent.parent / "reference.md" try: @@ -372,6 +373,7 @@ async def generate_stream(): stream_kwargs: Dict[str, Any] = { "model": model, "max_tokens": 16000, + "temperature": CHAT_TEMPERATURE, "system": system_blocks, "messages": conversation, } @@ -495,7 +497,9 @@ async def generate_stream(): assistant_message["content"].append({"type": "tool_use", "id": tu["id"], "name": tu["name"], "input": tu["input"]}) conversation.append(assistant_message) - # Execute tools in parallel + # Execute tools in parallel and stream results as each finishes. + # The model-facing transcript below remains deterministic because + # it appends tool results in the original tool-call order. logger.info(f"[CHAT] Executing {len(tool_uses)} tools: {[t['name'] for t in tool_uses]}") async def execute_tool_async(tu): diff --git a/backend/tests/test_agent_tools.py b/backend/tests/test_agent_tools.py index 191d691..f8875bc 100644 --- a/backend/tests/test_agent_tools.py +++ b/backend/tests/test_agent_tools.py @@ -6,14 +6,21 @@ import importlib.util import os +from types import SimpleNamespace import pytest +import agent_tools +import tool_definitions from agent_tools import ( get_baseline_parameters, + validate_reform, calculate_household, + run_economy_simulation, + analyse_microdata, generate_chart, execute_tool, + TOOL_DEFINITIONS, _build_compiled_policy, _json_safe, run_python, @@ -26,6 +33,10 @@ ) +def _tool(name: str) -> dict: + return next(tool for tool in TOOL_DEFINITIONS if tool["name"] == name) + + # --------------------------------------------------------------------------- # policyengine_uk_compiled interface # --------------------------------------------------------------------------- @@ -252,6 +263,269 @@ def test_serialises_simulation_like_objects(self): assert serialised == {"result": {"child_benefit": 123}} +class TestToolDefinitions: + def test_exposes_typed_tools_and_fallback_python(self): + names = [tool["name"] for tool in TOOL_DEFINITIONS] + assert "validate_reform" in names + assert "calculate_household" in names + assert "run_economy_simulation" in names + assert "analyse_microdata" in names + assert "run_python" in names + + def test_tool_definitions_match_dispatch_handlers(self): + definition_names = [tool["name"] for tool in TOOL_DEFINITIONS] + assert len(definition_names) == len(set(definition_names)) + assert set(definition_names) == set(agent_tools.TOOL_HANDLERS) + + def test_agent_tools_reexports_canonical_tool_definitions(self): + assert TOOL_DEFINITIONS is tool_definitions.TOOL_DEFINITIONS + + def test_shared_schema_fragments_are_reused(self): + household_schema = _tool("calculate_household")["input_schema"] + economy_schema = _tool("run_economy_simulation")["input_schema"] + microdata_schema = _tool("analyse_microdata")["input_schema"] + chart_schema = _tool("generate_chart")["input_schema"] + + assert household_schema["properties"]["year"] is tool_definitions.YEAR_SCHEMA + assert economy_schema["properties"]["year"] is tool_definitions.YEAR_SCHEMA + assert microdata_schema["properties"]["year"] is tool_definitions.YEAR_SCHEMA + assert economy_schema["properties"]["reform"] is tool_definitions.REFORM_PROPERTY + assert microdata_schema["properties"]["reform"] is tool_definitions.REFORM_PROPERTY + assert microdata_schema["properties"]["columns"] is tool_definitions.STRING_ARRAY_SCHEMA + assert chart_schema["properties"]["x_format"]["enum"] == tool_definitions.CHART_FORMAT_SCHEMA["enum"] + + def test_validate_reform_tool_is_debugging_tool(self): + description = _tool("validate_reform")["description"] + assert "without running a simulation" in description + assert "drafting" in description + assert "routine preflight" in description + + def test_analyse_microdata_schema_excludes_frs(self): + dataset_schema = _tool("analyse_microdata")["input_schema"]["properties"]["dataset"] + assert dataset_schema["default"] == "efrs" + assert "frs" not in dataset_schema["enum"] + assert "efrs" in dataset_schema["enum"] + + def test_run_python_is_described_as_fallback(self): + description = _tool("run_python")["description"] + assert "fallback" in description.lower() + assert "calculate_household" in description + assert "run_economy_simulation" in description + assert "analyse_microdata" in description + + +class TestAnalyseMicrodataContract: + @pytest.mark.parametrize("dataset", ["frs", "FRS"]) + def test_rejects_frs_before_loading_microdata(self, monkeypatch, dataset): + def fail_if_called(*args, **kwargs): + raise AssertionError("FRS rejection must happen before loading or building simulations") + + monkeypatch.setattr(agent_tools, "_get_cached_microdata", fail_if_called) + monkeypatch.setattr(agent_tools, "_build_simulation", fail_if_called) + monkeypatch.setattr(agent_tools, "_build_compiled_policy", fail_if_called) + monkeypatch.setattr(agent_tools, "_build_structural_reform", fail_if_called) + + result = analyse_microdata(entity="households", operation="count", dataset=dataset) + assert "error" in result + assert "FRS" in result["error"] + + def _install_mock_microdata(self, monkeypatch, row_count=4): + import pandas as pd + + household_ids = list(range(row_count)) + regions = ["North" if i % 2 == 0 else "South" for i in household_ids] + weights = [2.0, 1.0, 3.0, 4.0] if row_count == 4 else [1.0] * row_count + incomes = [10.0, 30.0, 20.0, 40.0] if row_count == 4 else [float(i) for i in household_ids] + + households = pd.DataFrame( + { + "household_id": household_ids, + "region": regions, + "weight": weights, + "baseline_net_income": incomes, + "reform_net_income": [income + 5.0 for income in incomes], + "baseline_total_tax": [income / 10.0 for income in incomes], + "reform_total_tax": [income / 10.0 + 1.0 for income in incomes], + "baseline_total_benefits": [1.0 for _ in household_ids], + "reform_total_benefits": [2.0 for _ in household_ids], + } + ) + persons = pd.DataFrame( + { + "person_id": household_ids, + "household_id": household_ids, + "age": [30 + i for i in household_ids], + "gender": ["F" if i % 2 == 0 else "M" for i in household_ids], + "employment_income": incomes, + "baseline_income_tax": [income / 10.0 for income in incomes], + "reform_income_tax": [income / 10.0 + 1.0 for income in incomes], + "baseline_total_income": incomes, + "reform_total_income": [income + 5.0 for income in incomes], + } + ) + benunits = pd.DataFrame( + { + "benunit_id": household_ids, + "household_id": household_ids, + "baseline_total_benefits": [1.0 for _ in household_ids], + "reform_total_benefits": [2.0 for _ in household_ids], + } + ) + microdata = SimpleNamespace(persons=persons, benunits=benunits, households=households) + + calls = [] + + def get_cached_microdata(year, reform, dataset): + calls.append({"year": year, "reform": reform, "dataset": dataset}) + return microdata + + monkeypatch.setattr(agent_tools, "_get_cached_microdata", get_cached_microdata) + return calls + + def test_count_uses_weights_after_filtering(self, monkeypatch): + self._install_mock_microdata(monkeypatch) + + result = analyse_microdata( + entity="households", + operation="count", + filters={"region": "North"}, + dataset="efrs", + ) + + assert result["row_count"] == 2 + assert result["result"] == {"row_count": 2, "weighted_population": 5} + + def test_mean_uses_weights(self, monkeypatch): + self._install_mock_microdata(monkeypatch) + + result = analyse_microdata( + entity="households", + operation="mean", + columns=["baseline_net_income"], + dataset="efrs", + ) + + assert result["result"]["baseline_net_income"] == pytest.approx(27.0) + + def test_group_by_returns_weighted_group_means(self, monkeypatch): + self._install_mock_microdata(monkeypatch) + + result = analyse_microdata( + entity="households", + operation="group_by", + columns=["baseline_net_income"], + group_by=["region"], + dataset="efrs", + ) + + rows = {row["region"]: row for row in result["result"]} + assert rows["North"]["row_count"] == 2 + assert rows["North"]["weighted_population"] == 5.0 + assert rows["North"]["baseline_net_income"] == pytest.approx(16.0) + assert rows["South"]["row_count"] == 2 + assert rows["South"]["weighted_population"] == 5.0 + assert rows["South"]["baseline_net_income"] == pytest.approx(38.0) + + def test_sample_is_deterministic_and_capped_for_non_frs(self, monkeypatch): + self._install_mock_microdata(monkeypatch, row_count=25) + + kwargs = { + "entity": "households", + "operation": "sample", + "columns": ["household_id", "baseline_net_income"], + "n": 100, + "dataset": "efrs", + } + first = analyse_microdata(**kwargs) + second = analyse_microdata(**kwargs) + + assert len(first["result"]) == 20 + assert first["result"] == second["result"] + + +# --------------------------------------------------------------------------- +# validate_reform +# --------------------------------------------------------------------------- + +class TestValidateReform: + @pytest.fixture(autouse=True) + def mock_parameter_classes(self, monkeypatch): + import tooling.reforms as reform_helpers + + class DummyIncomeTaxParams: + model_fields = {"personal_allowance": None, "higher_rate": None} + + def __init__(self, **kwargs): + self.kwargs = kwargs + + monkeypatch.setattr( + reform_helpers, + "_parameter_classes", + lambda: ({"income_tax": DummyIncomeTaxParams}, object, object), + ) + + def test_valid_reform_returns_normalized_reform(self): + result = validate_reform({"income_tax": {"personal_allowance": 15000}}) + + assert result["valid"] is True + assert result["normalized_reform"] == {"income_tax": {"personal_allowance": 15000}} + assert result["programs"] == ["income_tax"] + assert result["warnings"] == [] + + def test_unknown_program_returns_json_error(self): + result = validate_reform({"not_real": {"field": 1}}) + + assert result["valid"] is False + assert result["errors"][0]["path"] == "not_real" + assert "Unknown reform program" in result["errors"][0]["message"] + assert result["valid_programs"] == ["income_tax"] + + def test_unknown_field_returns_json_error(self): + result = validate_reform({"income_tax": {"not_real_field": 1}}) + + assert result["valid"] is False + assert result["errors"][0]["path"] == "income_tax.not_real_field" + assert "Unknown field" in result["errors"][0]["message"] + + def test_null_fields_are_stripped(self): + result = validate_reform( + {"income_tax": {"personal_allowance": 15000, "higher_rate": None}} + ) + + assert result["valid"] is True + assert result["normalized_reform"] == {"income_tax": {"personal_allowance": 15000}} + + def test_simulation_and_validator_share_invalid_reform_rules(self): + reform = {"income_tax": {"not_real_field": 1}} + + validation = validate_reform(reform) + simulation = run_economy_simulation(reform=reform) + + assert validation["valid"] is False + assert "Unknown field" in validation["errors"][0]["message"] + assert "error" in simulation + assert "Unknown field" in simulation["error"] + + +@requires_compiled +class TestValidateReformCompiledPath: + def test_valid_reform_uses_compiled_parameter_classes(self): + result = validate_reform({"income_tax": {"personal_allowance": 15000}}) + + assert result["valid"] is True + assert result["normalized_reform"] == {"income_tax": {"personal_allowance": 15000}} + + def test_valid_programs_match_compiled_parameter_classes(self): + from tooling.reforms import _parameter_classes, get_valid_programs + + param_cls_map, _, _ = _parameter_classes() + + assert get_valid_programs() == list(param_cls_map) + result = validate_reform({"not_real": {"field": 1}}) + assert result["valid"] is False + assert result["valid_programs"] == list(param_cls_map) + + @requires_compiled class TestRunPython: def test_replaces_old_compute_sum_use_case(self): @@ -322,6 +596,26 @@ def test_compute_is_not_exposed(self): result = execute_tool("compute", {"operation": "sum", "data": [1, 2, 3]}) assert result["error"] == "Unknown tool: compute" + def test_dispatches_validate_reform(self, monkeypatch): + monkeypatch.setitem(agent_tools.TOOL_HANDLERS, "validate_reform", lambda **kwargs: {"tool": "validator", "input": kwargs}) + result = execute_tool("validate_reform", {"reform": {}}) + assert result["tool"] == "validator" + + def test_dispatches_calculate_household(self, monkeypatch): + monkeypatch.setitem(agent_tools.TOOL_HANDLERS, "calculate_household", lambda **kwargs: {"tool": "household", "input": kwargs}) + result = execute_tool("calculate_household", {"person": [], "benunit": [], "household": []}) + assert result["tool"] == "household" + + def test_dispatches_run_economy_simulation(self, monkeypatch): + monkeypatch.setitem(agent_tools.TOOL_HANDLERS, "run_economy_simulation", lambda **kwargs: {"tool": "economy", "input": kwargs}) + result = execute_tool("run_economy_simulation", {"year": 2025}) + assert result["tool"] == "economy" + + def test_dispatches_analyse_microdata(self, monkeypatch): + monkeypatch.setitem(agent_tools.TOOL_HANDLERS, "analyse_microdata", lambda **kwargs: {"tool": "microdata", "input": kwargs}) + result = execute_tool("analyse_microdata", {"entity": "households", "operation": "count"}) + assert result["tool"] == "microdata" + def test_dispatches_generate_chart(self): result = execute_tool("generate_chart", { "chart_type": "line", "title": "T", diff --git a/backend/tests/test_prompts.py b/backend/tests/test_prompts.py index 745491d..dfa12c4 100644 --- a/backend/tests/test_prompts.py +++ b/backend/tests/test_prompts.py @@ -36,6 +36,23 @@ def test_main_prompt_contains_microdata_privacy_rules(): assert "cannot access or disclose real households" in SYSTEM_PROMPT assert "illustrative synthetic households" in SYSTEM_PROMPT assert "Simulation.single_person()" in SYSTEM_PROMPT + assert "analyse_microdata` must not be used with FRS" in SYSTEM_PROMPT + + +def test_main_prompt_prefers_typed_tools_before_python(): + assert "calculate_household" in SYSTEM_PROMPT + assert "run_economy_simulation" in SYSTEM_PROMPT + assert "analyse_microdata" in SYSTEM_PROMPT + assert "validate_reform" in SYSTEM_PROMPT + assert "routine" in SYSTEM_PROMPT + assert "preflight" in SYSTEM_PROMPT + assert "Use `run_python` as the fallback" in SYSTEM_PROMPT + + +def test_validate_reform_tool_is_not_routine_preflight(): + description = _tool("validate_reform")["description"] + assert "without running a simulation" in description + assert "routine preflight" in description def test_run_python_tool_repeats_microdata_contract(): @@ -44,6 +61,16 @@ def test_run_python_tool_repeats_microdata_contract(): assert "illustrative synthetic households" in description assert "Simulation.single_person()" in description assert "rather than real households" in description + assert "fallback" in description.lower() + + +def test_analyse_microdata_tool_excludes_frs(): + tool = _tool("analyse_microdata") + description = tool["description"] + dataset_schema = tool["input_schema"]["properties"]["dataset"] + assert "does not support FRS" in description + assert dataset_schema["default"] == "efrs" + assert "frs" not in dataset_schema["enum"] def test_generate_chart_tool_requires_neutral_titles(): @@ -51,6 +78,7 @@ def test_generate_chart_tool_requires_neutral_titles(): description = chart_tool["description"] title_description = chart_tool["input_schema"]["properties"]["title"]["description"] assert "factually neutral" in description + assert "typed calculation tool or `run_python`" in description assert "factually neutral" in title_description.lower() diff --git a/backend/tests/test_structural_tools.py b/backend/tests/test_structural_tools.py index 2b2ae49..cc143cc 100644 --- a/backend/tests/test_structural_tools.py +++ b/backend/tests/test_structural_tools.py @@ -1,14 +1,25 @@ from pathlib import Path +import importlib.util +import os import sys from types import SimpleNamespace +import pytest + sys.path.insert(0, str(Path(__file__).resolve().parents[1])) sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "policyengine-uk-rust" / "interfaces" / "python")) import agent_tools from agent_tools import _build_structural_reform +COMPILED_AVAILABLE = importlib.util.find_spec("policyengine_uk_compiled") is not None +requires_compiled = pytest.mark.skipif( + os.environ.get("CI") != "true" and not COMPILED_AVAILABLE, + reason="policyengine_uk_compiled is not installed", +) + +@requires_compiled def test_build_structural_reform_pre_hook(): structural = _build_structural_reform( { @@ -34,6 +45,7 @@ def test_build_structural_reform_rejects_unknown_fields(): raise AssertionError("Expected ValueError for unknown structural_reform field") +@requires_compiled def test_run_economy_simulation_uses_true_baseline_hbai(monkeypatch): class DummyDump: def __init__(self, value): diff --git a/backend/tool_definitions.py b/backend/tool_definitions.py new file mode 100644 index 0000000..c7c5b40 --- /dev/null +++ b/backend/tool_definitions.py @@ -0,0 +1,267 @@ +"""Model-facing tool definitions for the UK chat runtime.""" + +from tooling.reforms import REFORM_SCHEMA + + +YEAR_SCHEMA = {"type": "integer", "default": 2025} + +REFORM_PROPERTY = REFORM_SCHEMA + +STRING_ARRAY_SCHEMA = {"type": "array", "items": {"type": "string"}} + +ALL_DATASET_SCHEMA = { + "type": "string", + "enum": ["frs", "efrs", "spi", "lcfs", "was"], + "default": "frs", + "description": "Microdata source for aggregate simulation. FRS is the default for aggregate outputs.", +} + +NON_FRS_DATASET_SCHEMA = { + "type": "string", + "enum": ["efrs", "spi", "lcfs", "was"], + "default": "efrs", + "description": "FRS is not available for analyse_microdata.", +} + +FILTERS_SCHEMA = { + "type": "object", + "description": ( + "Column to predicate map. Predicate can be a scalar, a list, or a " + "dict with min, max, gt, lt, gte, lte, or ne." + ), +} + +CHART_FORMAT_SCHEMA = { + "type": "string", + "enum": ["currency", "percent", "percent_decimal", "number", "compact", "year"], + "description": ( + "Number format for axis ticks and tooltips. Use `currency` for GBP " + "amounts, `percent` for values already on a 0-100 scale, " + "`percent_decimal` for 0-1 shares, `compact` for large counts (1.2k), " + "`year` for calendar years." + ), +} + +CHART_DATA_SCHEMA = { + "type": "array", + "description": "List of row objects. Each row must contain the `x_field` key and every key listed in `y_fields`.", + "items": {"type": "object"}, +} + + +VALIDATE_REFORM_INPUT_SCHEMA = { + "type": "object", + "properties": { + "reform": REFORM_PROPERTY, + }, + "required": ["reform"], +} + +HOUSEHOLD_RECORD_SCHEMA = {"type": "array", "items": {"type": "object"}} + +CALCULATE_HOUSEHOLD_INPUT_SCHEMA = { + "type": "object", + "properties": { + "person": { + **HOUSEHOLD_RECORD_SCHEMA, + "description": ( + "Person records. Each should include person_id, benunit_id, " + "household_id, and age. Common optional fields include " + "employment_income, self_employment_income, and pension_income." + ), + }, + "benunit": { + **HOUSEHOLD_RECORD_SCHEMA, + "description": "Benefit-unit records, each with benunit_id and household_id.", + }, + "household": { + **HOUSEHOLD_RECORD_SCHEMA, + "description": ( + "Household records, each with household_id. Add location fields " + "when relevant, for example region or is_in_scotland." + ), + }, + "year": YEAR_SCHEMA, + "reform": REFORM_PROPERTY, + }, + "required": ["person", "benunit", "household"], +} + +RUN_ECONOMY_SIMULATION_INPUT_SCHEMA = { + "type": "object", + "properties": { + "year": YEAR_SCHEMA, + "reform": REFORM_PROPERTY, + "dataset": ALL_DATASET_SCHEMA, + }, + "required": [], +} + +ANALYSE_MICRODATA_INPUT_SCHEMA = { + "type": "object", + "properties": { + "entity": {"type": "string", "enum": ["persons", "benunits", "households"]}, + "operation": {"type": "string", "enum": ["sample", "mean", "sum", "count", "group_by", "describe"]}, + "year": YEAR_SCHEMA, + "reform": REFORM_PROPERTY, + "filters": FILTERS_SCHEMA, + "columns": STRING_ARRAY_SCHEMA, + "group_by": STRING_ARRAY_SCHEMA, + "n": {"type": "integer", "default": 5, "description": "Sample size when operation is sample."}, + "dataset": NON_FRS_DATASET_SCHEMA, + }, + "required": ["entity", "operation"], +} + +RUN_PYTHON_INPUT_SCHEMA = { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": ( + "Python code to execute. Must assign the final answer to `result`. " + "Use the preloaded PolicyEngine interface directly, for example: " + "`sim = Simulation(year=2025)` or `policy = Parameters.model_validate({...})`." + ), + }, + }, + "required": ["code"], +} + +GENERATE_CHART_INPUT_SCHEMA = { + "type": "object", + "properties": { + "chart_type": { + "type": "string", + "enum": ["line", "bar", "area", "scatter"], + "description": ( + "Chart type. Use `line` for schedules/curves over a continuous x, " + "`bar` for category comparisons (e.g. deciles), `area` for " + "stacked compositions, `scatter` for point clouds." + ), + }, + "title": {"type": "string", "description": "Factually neutral chart title shown above the plot."}, + "data": CHART_DATA_SCHEMA, + "x_field": {"type": "string", "description": "Key in each data row to use as the x value."}, + "y_fields": { + **STRING_ARRAY_SCHEMA, + "description": ( + "Keys in each data row to plot as y series. Provide multiple for " + "multi-series charts (e.g. baseline vs reform)." + ), + }, + "x_label": {"type": "string", "description": "Axis label for x (defaults to `x_field`)."}, + "y_label": {"type": "string", "description": "Axis label for y (defaults to first y field or 'Value')."}, + "x_format": {**CHART_FORMAT_SCHEMA, "description": f"X-axis {CHART_FORMAT_SCHEMA['description']}"}, + "y_format": {**CHART_FORMAT_SCHEMA, "description": f"Y-axis {CHART_FORMAT_SCHEMA['description']}"}, + "x_min": {"type": "number", "description": "Optional fixed minimum for the x axis."}, + "x_max": {"type": "number", "description": "Optional fixed maximum for the x axis."}, + "y_min": {"type": "number", "description": "Optional fixed minimum for the y axis."}, + "y_max": {"type": "number", "description": "Optional fixed maximum for the y axis."}, + "series_labels": { + **STRING_ARRAY_SCHEMA, + "description": "Display labels for each y series, in the same order as `y_fields`.", + }, + "series_styles": { + "type": "array", + "description": "Line style per series (line/area charts).", + "items": {"type": "string", "enum": ["solid", "dashed", "dotted"]}, + }, + "series_curves": { + "type": "array", + "description": "Curve interpolation per series (line/area charts).", + "items": {"type": "string", "enum": ["smooth", "step", "linear"]}, + }, + "subtitle": {"type": "string", "description": "Optional subtitle shown under the title."}, + "source": {"type": "string", "description": "Optional source/caption shown beneath the chart."}, + "arrangement": { + "type": "string", + "enum": ["grouped", "stacked"], + "description": "For bar charts only: `grouped` side-by-side or `stacked`.", + }, + "area_fill": {"type": "boolean", "description": "For line charts only: fill the area under the line."}, + }, + "required": ["chart_type", "title", "data", "x_field", "y_fields"], +} + +TOOL_DEFINITIONS = [ + { + "name": "validate_reform", + "description": ( + "Validate parametric reform JSON without running a simulation. " + "Use this when the user is drafting, debugging, or asking whether " + "a reform object is valid. Do not call it as a routine preflight " + "before every simulation; calculation tools validate reforms internally." + ), + "input_schema": VALIDATE_REFORM_INPUT_SCHEMA, + }, + { + "name": "calculate_household", + "description": ( + "Compute taxes, benefits, and net income for an illustrative " + "specific household described with person, benefit-unit, and " + "household records. Prefer this over run_python for household-level " + "questions with a defined household composition. These inputs are " + "synthetic examples, not real households." + ), + "input_schema": CALCULATE_HOUSEHOLD_INPUT_SCHEMA, + }, + { + "name": "run_economy_simulation", + "description": ( + "Run a UK economy-wide microsimulation comparing baseline current " + "law to a parametric reform. Returns aggregate outputs including " + "budgetary impact, programme breakdown, decile impacts, " + "winners/losers, caseloads, HBAI incomes, and poverty metrics. " + "Prefer this over run_python for society-wide reform analysis. " + "Use run_python for structural reforms." + ), + "input_schema": RUN_ECONOMY_SIMULATION_INPUT_SCHEMA, + }, + { + "name": "analyse_microdata", + "description": ( + "Slice, filter, sample, or aggregate non-FRS model microdata for a " + "given year and optional parametric reform. Use this for allowed " + "non-FRS microdata follow-ups such as subset means, counts, group " + "breakdowns, descriptions, or small model-record samples. This tool " + "explicitly does not support FRS; use run_economy_simulation for " + "aggregate FRS outputs." + ), + "input_schema": ANALYSE_MICRODATA_INPUT_SCHEMA, + }, + { + "name": "run_python", + "description": ( + "Execute reproducible Python code using the official PolicyEngine UK compiled interface. " + "Prefer the typed tools (`calculate_household`, `run_economy_simulation`, `analyse_microdata`) " + "when the question fits their shape; use `run_python` as a fallback for structural reforms, " + "novel aggregations, parameter introspection, historical lookups, or unsupported cases. " + "The environment preloads `policyengine_uk_compiled` as `pe`, plus `Simulation`, `Parameters`, " + "`StructuralReform`, `aggregate_microdata`, `combine_microdata`, `capabilities`, " + "`ensure_dataset`, `pd`, `np`, `json`, and `math`. Assign the final answer to `result` and " + "use `print()` for intermediate output. Do not inspect or return row-level survey microdata, " + "including FRS data. For household examples, create illustrative synthetic households, prefer " + "`Simulation.single_person()` for single-person examples, and label them as illustrative rather " + "than real households." + ), + "input_schema": RUN_PYTHON_INPUT_SCHEMA, + }, + { + "name": "generate_chart", + "description": ( + "Generate a chart JSON block for the frontend to render. " + "Use this for visualisations such as income distributions, marginal-rate or tax-schedule curves, " + "decile impact comparisons, and trends over time or income. " + "Use factually neutral titles, subtitles, labels, and captions; do not call policies good, bad, fair, unfair, " + "regressive, progressive, generous, or punitive. " + "The tool returns a `chart_markdown` field containing a ```chart fenced JSON block - you MUST paste that " + "string verbatim into your next text response, otherwise the chart will not appear to the user. " + "Do not attempt to render charts with matplotlib inside `run_python`; the UI cannot display matplotlib output. " + "Compute the data first with a typed calculation tool or `run_python` " + "(returning a list of row dicts), then pass it to this tool." + ), + "input_schema": GENERATE_CHART_INPUT_SCHEMA, + }, +] + diff --git a/backend/tooling/__init__.py b/backend/tooling/__init__.py new file mode 100644 index 0000000..7af5084 --- /dev/null +++ b/backend/tooling/__init__.py @@ -0,0 +1,2 @@ +"""Shared deterministic helpers for UK chat tools.""" + diff --git a/backend/tooling/households.py b/backend/tooling/households.py new file mode 100644 index 0000000..1092c77 --- /dev/null +++ b/backend/tooling/households.py @@ -0,0 +1,87 @@ +"""Illustrative household input normalization.""" + +from typing import Any, Dict, List, Tuple + +from tooling.simulations import ensure_compiled_package_importable + + +def build_household_frames( + person: List[Dict[str, Any]], + benunit: List[Dict[str, Any]], + household: List[Dict[str, Any]], +) -> Tuple[Any, Any, Any]: + ensure_compiled_package_importable() + import pandas as pd + from policyengine_uk_compiled import BENUNIT_DEFAULTS, HOUSEHOLD_DEFAULTS, PERSON_DEFAULTS + + def fill_defaults(records, defaults): + return pd.DataFrame([{**defaults, **rec} for rec in records]) + + hh_id_map = {rec["household_id"]: i for i, rec in enumerate(household)} + bu_id_map = {rec["benunit_id"]: i for i, rec in enumerate(benunit)} + person = [ + { + **rec, + "person_id": i, + "benunit_id": bu_id_map[rec["benunit_id"]], + "household_id": hh_id_map[rec["household_id"]], + } + for i, rec in enumerate(person) + ] + benunit = [ + { + **rec, + "benunit_id": bu_id_map[rec["benunit_id"]], + "household_id": hh_id_map[rec["household_id"]], + } + for rec in benunit + ] + household = [{**rec, "household_id": hh_id_map[rec["household_id"]]} for rec in household] + + seen_bu_heads = set() + seen_hh_heads = set() + for rec in person: + bu_id = rec["benunit_id"] + hh_id = rec["household_id"] + is_adult = rec.get("age", 30) >= 16 + rec["is_benunit_head"] = is_adult and bu_id not in seen_bu_heads + rec["is_household_head"] = is_adult and hh_id not in seen_hh_heads + if rec["is_benunit_head"]: + seen_bu_heads.add(bu_id) + if rec["is_household_head"]: + seen_hh_heads.add(hh_id) + + persons_df = fill_defaults(person, PERSON_DEFAULTS) + benunits_df = fill_defaults(benunit, BENUNIT_DEFAULTS) + households_df = fill_defaults(household, HOUSEHOLD_DEFAULTS) + + if "person_ids" not in benunits_df.columns or ( + benunits_df["person_ids"] == BENUNIT_DEFAULTS.get("person_ids", 0) + ).all(): + bu_to_persons = persons_df.groupby("benunit_id")["person_id"].apply( + lambda ids: ",".join(str(i) for i in ids) + ) + benunits_df["person_ids"] = ( + benunits_df["benunit_id"].map(bu_to_persons).fillna(benunits_df["benunit_id"].astype(str)) + ) + if "benunit_ids" not in households_df.columns or ( + households_df["benunit_ids"] == HOUSEHOLD_DEFAULTS.get("benunit_ids", 0) + ).all(): + hh_to_benunits = benunits_df.groupby("household_id")["benunit_id"].apply( + lambda ids: ",".join(str(i) for i in ids) + ) + households_df["benunit_ids"] = ( + households_df["household_id"].map(hh_to_benunits).fillna(households_df["household_id"].astype(str)) + ) + if "person_ids" not in households_df.columns or ( + households_df["person_ids"] == HOUSEHOLD_DEFAULTS.get("person_ids", 0) + ).all(): + hh_to_persons = persons_df.groupby("household_id")["person_id"].apply( + lambda ids: ",".join(str(i) for i in ids) + ) + households_df["person_ids"] = ( + households_df["household_id"].map(hh_to_persons).fillna(households_df["household_id"].astype(str)) + ) + + return persons_df, benunits_df, households_df + diff --git a/backend/tooling/microdata.py b/backend/tooling/microdata.py new file mode 100644 index 0000000..761e945 --- /dev/null +++ b/backend/tooling/microdata.py @@ -0,0 +1,230 @@ +"""Microdata loading, filtering, and aggregate operations.""" + +import hashlib +import json +from typing import Any, Dict, List, Optional + +from tooling.reforms import build_compiled_policy +from tooling.serialization import json_safe +from tooling.simulations import DATASET_LABELS, build_simulation + + +_microdata_cache: Dict[tuple, Any] = {} +_MAX_CACHE = 4 + + +def hash_reform(reform: Optional[Dict[str, Any]]) -> str: + if not reform: + return "none" + return hashlib.md5(json.dumps(reform, sort_keys=True).encode()).hexdigest() + + +def get_cached_microdata(year: int, reform: Optional[Dict[str, Any]], dataset: str, structural=None): + """Return cached MicrodataResult. Structural reforms always run fresh.""" + if structural is not None: + policy = build_compiled_policy(reform) + sim = build_simulation(year, dataset) + return sim.run_microdata(policy=policy, structural=structural) + key = (year, hash_reform(reform), dataset) + if key not in _microdata_cache: + policy = build_compiled_policy(reform) + sim = build_simulation(year, dataset) + _microdata_cache[key] = sim.run_microdata(policy=policy) + if len(_microdata_cache) > _MAX_CACHE: + del _microdata_cache[next(iter(_microdata_cache))] + return _microdata_cache[key] + + +def analyse_microdata_result( + microdata, + entity: str, + operation: str, + year: int, + dataset_key: str, + reform_applied: bool, + structural_reform_applied: bool, + filters: Optional[Dict[str, Any]] = None, + columns: Optional[List[str]] = None, + group_by: Optional[List[str]] = None, + n: int = 5, +) -> Dict[str, Any]: + import pandas as pd + + entity_map = {"persons": microdata.persons, "benunits": microdata.benunits, "households": microdata.households} + if entity not in entity_map: + return {"error": "entity must be one of: persons, benunits, households"} + df = entity_map[entity].copy() + + weights = microdata.households[["household_id", "weight"]].copy() + if "household_id" in df.columns and "weight" not in df.columns: + df = df.merge(weights, on="household_id", how="left") + elif "weight" not in df.columns: + df["weight"] = 1.0 + + change_pairs = { + "persons": [ + ("income_tax", "baseline_income_tax", "reform_income_tax"), + ("employee_ni", "baseline_employee_ni", "reform_employee_ni"), + ("total_income", "baseline_total_income", "reform_total_income"), + ], + "benunits": [ + ("total_benefits", "baseline_total_benefits", "reform_total_benefits"), + ("universal_credit", "baseline_universal_credit", "reform_universal_credit"), + ("child_benefit", "baseline_child_benefit", "reform_child_benefit"), + ], + "households": [ + ("net_income", "baseline_net_income", "reform_net_income"), + ("total_tax", "baseline_total_tax", "reform_total_tax"), + ("total_benefits", "baseline_total_benefits", "reform_total_benefits"), + ], + } + for change_col, base_col, reform_col in change_pairs.get(entity, []): + if base_col in df.columns and reform_col in df.columns: + df[f"{change_col}_change"] = df[reform_col] - df[base_col] + + filters_applied = {} + if filters: + for col, fval in filters.items(): + if col not in df.columns: + return {"error": f"Column '{col}' not found. Available: {list(df.columns)}"} + filters_applied[col] = fval + if isinstance(fval, dict): + if "min" in fval: + df = df[df[col] >= fval["min"]] + if "max" in fval: + df = df[df[col] <= fval["max"]] + if "gt" in fval: + df = df[df[col] > fval["gt"]] + if "lt" in fval: + df = df[df[col] < fval["lt"]] + if "gte" in fval: + df = df[df[col] >= fval["gte"]] + if "lte" in fval: + df = df[df[col] <= fval["lte"]] + if "ne" in fval: + df = df[df[col] != fval["ne"]] + elif isinstance(fval, list): + df = df[df[col].isin(fval)] + else: + df = df[df[col] == fval] + + row_count = len(df) + weighted_count = int(df["weight"].sum()) if "weight" in df.columns else row_count + all_cols = list(df.columns) + + if columns: + missing = [c for c in columns if c not in df.columns] + if missing: + return {"error": f"Columns not found: {missing}. Available: {all_cols}"} + value_cols = columns + else: + if entity == "persons": + value_cols = [ + "age", + "gender", + "employment_income", + "self_employment_income", + "baseline_income_tax", + "reform_income_tax", + "income_tax_change", + "baseline_total_income", + "reform_total_income", + "total_income_change", + ] + elif entity == "benunits": + value_cols = [ + "baseline_total_benefits", + "reform_total_benefits", + "total_benefits_change", + "baseline_universal_credit", + "reform_universal_credit", + "baseline_child_benefit", + "reform_child_benefit", + ] + else: + value_cols = [ + "region", + "baseline_net_income", + "reform_net_income", + "net_income_change", + "baseline_total_tax", + "reform_total_tax", + "baseline_total_benefits", + "reform_total_benefits", + ] + value_cols = [c for c in value_cols if c in df.columns] + + if operation == "sample": + actual_n = min(n, 20, row_count) + sample_df = df[value_cols].sample(n=actual_n, random_state=42) if row_count >= actual_n else df[value_cols] + result = [ + {k: (None if (isinstance(v, float) and str(v) == "nan") else v) for k, v in row.items()} + for row in sample_df.to_dict(orient="records") + ] + elif operation == "mean": + numeric_cols = [c for c in value_cols if pd.api.types.is_numeric_dtype(df[c]) and c != "weight"] + result = { + c: float((df[c] * df["weight"]).sum() / df["weight"].sum()) + if df["weight"].sum() > 0 + else float(df[c].mean()) + for c in numeric_cols + } + elif operation == "sum": + numeric_cols = [c for c in value_cols if pd.api.types.is_numeric_dtype(df[c]) and c != "weight"] + result = {c: float((df[c] * df["weight"]).sum()) for c in numeric_cols} + elif operation == "count": + result = {"row_count": row_count, "weighted_population": weighted_count} + elif operation == "group_by": + if not group_by: + return {"error": "group_by operation requires at least one group_by column"} + missing_groups = [c for c in group_by if c not in df.columns] + if missing_groups: + return {"error": f"Group columns not found: {missing_groups}. Available: {all_cols}"} + numeric_cols = [c for c in value_cols if pd.api.types.is_numeric_dtype(df[c]) and c != "weight"] + grouped_rows = [] + for keys, group in df.groupby(group_by, dropna=False): + if not isinstance(keys, tuple): + keys = (keys,) + row = {col: json_safe(value) for col, value in zip(group_by, keys)} + row["row_count"] = int(len(group)) + row["weighted_population"] = float(group["weight"].sum()) + for col in numeric_cols: + row[col] = ( + float((group[col] * group["weight"]).sum() / group["weight"].sum()) + if group["weight"].sum() > 0 + else float(group[col].mean()) + ) + grouped_rows.append(row) + result = grouped_rows + elif operation == "describe": + numeric_cols = [c for c in value_cols if pd.api.types.is_numeric_dtype(df[c]) and c != "weight"] + result = { + c: { + "mean": float((df[c] * df["weight"]).sum() / df["weight"].sum()) + if df["weight"].sum() > 0 + else float(df[c].mean()), + "min": float(df[c].min()), + "max": float(df[c].max()), + "count": int(df[c].count()), + } + for c in numeric_cols + } + for col in [c for c in value_cols if not pd.api.types.is_numeric_dtype(df[c])]: + result[col] = {str(k): int(v) for k, v in df[col].value_counts().head(10).items()} + else: + return {"error": f"Unknown operation '{operation}'. Use: mean, sum, count, sample, group_by, describe"} + + return { + "entity": entity, + "operation": operation, + "year": year, + "dataset": DATASET_LABELS.get(dataset_key, dataset_key), + "reform_applied": reform_applied, + "structural_reform_applied": structural_reform_applied, + "filters_applied": filters_applied, + "row_count": row_count, + "weighted_count": weighted_count, + "result": result, + "available_columns": all_cols, + } + diff --git a/backend/tooling/reforms.py b/backend/tooling/reforms.py new file mode 100644 index 0000000..b0480cf --- /dev/null +++ b/backend/tooling/reforms.py @@ -0,0 +1,196 @@ +"""Parametric reform validation and compiled-policy construction.""" + +from typing import Any, Dict, List, Optional, Tuple + +from tooling.simulations import ensure_compiled_package_importable + + +DEFAULT_VALID_PROGRAMS = [ + "income_tax", + "national_insurance", + "universal_credit", + "child_benefit", + "state_pension", + "pension_credit", + "benefit_cap", + "housing_benefit", + "tax_credits", + "scottish_child_payment", + "stamp_duty", + "capital_gains_tax", + "wealth_tax", +] + +class ReformValidationError(ValueError): + """Validation error carrying JSON-friendly reform errors.""" + + def __init__(self, errors: List[Dict[str, str]]): + self.errors = errors + message = errors[0]["message"] if errors else "Invalid reform" + super().__init__(message) + + +def _parameter_classes(): + ensure_compiled_package_importable() + from policyengine_uk_compiled import ( + BenefitCapParams, + CapitalGainsTaxParams, + ChildBenefitParams, + HousingBenefitParams, + IncomeTaxParams, + NationalInsuranceParams, + PensionCreditParams, + ScottishChildPaymentParams, + StampDutyBand, + StampDutyParams, + StatePensionParams, + TaxCreditsParams, + UniversalCreditParams, + WealthTaxParams, + ) + + return ( + { + "income_tax": IncomeTaxParams, + "national_insurance": NationalInsuranceParams, + "universal_credit": UniversalCreditParams, + "child_benefit": ChildBenefitParams, + "state_pension": StatePensionParams, + "pension_credit": PensionCreditParams, + "benefit_cap": BenefitCapParams, + "housing_benefit": HousingBenefitParams, + "tax_credits": TaxCreditsParams, + "scottish_child_payment": ScottishChildPaymentParams, + "stamp_duty": StampDutyParams, + "capital_gains_tax": CapitalGainsTaxParams, + "wealth_tax": WealthTaxParams, + }, + StampDutyParams, + StampDutyBand, + ) + + +def get_valid_programs() -> List[str]: + try: + param_cls_map, _, _ = _parameter_classes() + except ModuleNotFoundError: + return DEFAULT_VALID_PROGRAMS + return list(param_cls_map) + + +def build_reform_schema(valid_programs: Optional[List[str]] = None) -> Dict[str, Any]: + programs = valid_programs or get_valid_programs() + return { + "type": "object", + "description": ( + "Parametric reform. Top-level keys are programmes; values are the " + "parameter changes for that programme. Valid programmes include " + f"{', '.join(programs[:-1])}, and {programs[-1]}. " + "Field names within each programme match the corresponding *Params " + "constructor. For structural reforms, use run_python instead." + ), + "additionalProperties": True, + } + + +REFORM_SCHEMA = build_reform_schema() + + +def normalise_reform( + reform: Optional[Dict[str, Any]], +) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]: + """Validate and normalize a reform dict, returning JSON and model objects.""" + if not reform: + return {}, {} + if not isinstance(reform, dict): + raise ReformValidationError( + [{"path": "reform", "message": f"Reform must be a dict, got {type(reform).__name__}"}] + ) + + param_cls_map, stamp_duty_cls, stamp_duty_band_cls = _parameter_classes() + normalized: Dict[str, Dict[str, Any]] = {} + model_kwargs: Dict[str, Any] = {} + errors: List[Dict[str, str]] = [] + + for program, fields in reform.items(): + if program not in param_cls_map: + errors.append( + { + "path": str(program), + "message": f"Unknown reform program '{program}'. Valid: {list(param_cls_map)}", + } + ) + continue + if not isinstance(fields, dict): + errors.append( + { + "path": str(program), + "message": f"Reform program '{program}' must be a dict, got {type(fields).__name__}", + } + ) + continue + + cls = param_cls_map[program] + valid_fields = set(cls.model_fields) + unknown = sorted(k for k in fields if k not in valid_fields and fields[k] is not None) + if unknown: + for field in unknown: + errors.append( + { + "path": f"{program}.{field}", + "message": ( + f"Unknown field(s) for '{program}': {unknown}. " + f"Valid: {sorted(valid_fields)}" + ), + } + ) + continue + + cleaned_fields = {k: v for k, v in fields.items() if v is not None} + model_fields = dict(cleaned_fields) + if cls is stamp_duty_cls and "bands" in model_fields: + model_fields["bands"] = [ + stamp_duty_band_cls(**band) if isinstance(band, dict) else band + for band in model_fields["bands"] + ] + try: + model_kwargs[program] = cls(**model_fields) + except Exception as exc: + errors.append({"path": str(program), "message": f"{type(exc).__name__}: {exc}"}) + continue + if cleaned_fields: + normalized[program] = cleaned_fields + + if errors: + raise ReformValidationError(errors) + return normalized, model_kwargs + + +def build_compiled_policy(reform: Optional[Dict[str, Any]]): + normalized, model_kwargs = normalise_reform(reform) + if not normalized: + return None + ensure_compiled_package_importable() + from policyengine_uk_compiled import Parameters + + return Parameters(**model_kwargs) + + +def validate_reform_dict(reform: Optional[Dict[str, Any]]) -> Dict[str, Any]: + try: + normalized, _ = normalise_reform(reform) + except ReformValidationError as exc: + return {"valid": False, "errors": exc.errors, "valid_programs": get_valid_programs()} + except Exception as exc: + return { + "valid": False, + "errors": [{"path": "reform", "message": f"{type(exc).__name__}: {exc}"}], + "valid_programs": get_valid_programs(), + } + + return { + "valid": True, + "normalized_reform": normalized, + "programs": list(normalized), + "warnings": [], + } diff --git a/backend/tooling/sandbox.py b/backend/tooling/sandbox.py new file mode 100644 index 0000000..c8b82b1 --- /dev/null +++ b/backend/tooling/sandbox.py @@ -0,0 +1,263 @@ +"""Restricted Python execution helpers used by chat tools.""" + +import builtins as _builtins +import json +import math +from typing import Any, Callable, Dict, List, Optional + +from tooling.serialization import json_safe +from tooling.simulations import ensure_compiled_package_importable + + +ALLOWED_IMPORT_ROOTS = {"json", "math", "numpy", "pandas"} + + +def safe_import(name, globals=None, locals=None, fromlist=(), level=0): + root_name = name.split(".")[0] + if root_name not in ALLOWED_IMPORT_ROOTS: + raise ImportError(f"Import of '{name}' is not allowed") + return __import__(name, globals, locals, fromlist, level) + + +def safe_builtins(names, print_func: Optional[Callable[..., None]] = None, allow_import: bool = False): + builtins = {name: getattr(_builtins, name) for name in names if hasattr(_builtins, name)} + if print_func is not None: + builtins["print"] = print_func + if allow_import: + builtins["__import__"] = safe_import + return builtins + + +def optional_numpy(): + try: + import numpy as np + except ImportError: + return None + return np + + +def compile_structural_hook(code: str): + """Compile a structural hook from code defining hook(...).""" + safe_names = ( + "range", + "len", + "int", + "float", + "str", + "bool", + "list", + "dict", + "tuple", + "set", + "zip", + "enumerate", + "map", + "filter", + "sorted", + "reversed", + "min", + "max", + "sum", + "abs", + "round", + "True", + "False", + "None", + "isinstance", + "ValueError", + "TypeError", + "print", + "any", + "all", + "pow", + "divmod", + ) + try: + import pandas as pd + except ImportError as exc: + raise ImportError("pandas is required for structural reform hooks") from exc + + allowed_globals: Dict[str, Any] = { + "__builtins__": safe_builtins(safe_names), + "math": math, + "json": json, + "pd": pd, + } + np = optional_numpy() + if np is not None: + allowed_globals["np"] = np + allowed_globals["numpy"] = np + + exec(code, allowed_globals) + hook = allowed_globals.get("hook") + if hook is None or not callable(hook): + raise ValueError("Structural hook code must define a callable `hook(year, persons, benunits, households)`") + return hook + + +def build_structural_reform(structural_reform: Optional[Dict[str, Any]]): + if not structural_reform: + return None + if not isinstance(structural_reform, dict): + raise ValueError(f"structural_reform must be a dict, got {type(structural_reform).__name__}") + + unknown = set(structural_reform) - {"pre", "post"} + if unknown: + raise ValueError(f"Unknown structural_reform field(s): {sorted(unknown)}. Valid: ['pre', 'post']") + + ensure_compiled_package_importable() + from policyengine_uk_compiled import StructuralReform + + pre = structural_reform.get("pre") + post = structural_reform.get("post") + if pre is not None and not isinstance(pre, str): + raise ValueError("structural_reform.pre must be a string of Python code defining hook(...)") + if post is not None and not isinstance(post, str): + raise ValueError("structural_reform.post must be a string of Python code defining hook(...)") + + return StructuralReform( + pre=compile_structural_hook(pre) if pre else None, + post=compile_structural_hook(post) if post else None, + ) + + +def run_python_code(code: str) -> Dict[str, Any]: + ensure_compiled_package_importable() + import pandas as pd + import policyengine_uk_compiled as pe + from policyengine_uk_compiled import ( + Parameters, + Simulation, + StructuralReform, + aggregate_microdata, + capabilities, + combine_microdata, + ensure_dataset, + ) + + safe_names = ( + "range", + "len", + "int", + "float", + "str", + "bool", + "list", + "dict", + "tuple", + "set", + "zip", + "enumerate", + "map", + "filter", + "sorted", + "reversed", + "min", + "max", + "sum", + "abs", + "round", + "True", + "False", + "None", + "isinstance", + "ValueError", + "TypeError", + "Exception", + "print", + "any", + "all", + "pow", + "divmod", + "complex", + "type", + "dir", + "hasattr", + "getattr", + ) + output_lines: List[str] = [] + + def safe_print(*args, **kwargs): + output_lines.append(" ".join(str(arg) for arg in args)) + + allowed_globals: Dict[str, Any] = { + "__builtins__": safe_builtins(safe_names, print_func=safe_print, allow_import=True), + "math": math, + "json": json, + "pd": pd, + "pe": pe, + "Simulation": Simulation, + "StructuralReform": StructuralReform, + "Parameters": Parameters, + "aggregate_microdata": aggregate_microdata, + "combine_microdata": combine_microdata, + "capabilities": capabilities, + "ensure_dataset": ensure_dataset, + } + np = optional_numpy() + if np is not None: + allowed_globals["np"] = np + allowed_globals["numpy"] = np + + try: + exec(code, allowed_globals) + except Exception as exc: + return {"error": f"{type(exc).__name__}: {exc}"} + + result = allowed_globals.get("result", None) + response: Dict[str, Any] = {} + if result is not None: + response["result"] = json_safe(result) + if output_lines: + response["output"] = "\n".join(output_lines) + if not response: + response["result"] = None + response["note"] = "No 'result' variable was set and nothing was printed." + return response + + +def run_generator(code: str) -> Dict[str, Any]: + """Execute a Python generator snippet that returns a dict of tool kwargs.""" + safe_names = ( + "range", + "len", + "int", + "float", + "str", + "bool", + "list", + "dict", + "tuple", + "set", + "zip", + "enumerate", + "map", + "filter", + "sorted", + "reversed", + "min", + "max", + "sum", + "abs", + "round", + "True", + "False", + "None", + "isinstance", + "ValueError", + "TypeError", + "append", + ) + allowed_globals: Dict[str, Any] = { + "__builtins__": safe_builtins(safe_names), + "math": math, + "json": json, + } + exec(code, allowed_globals) + if "generate" not in allowed_globals: + raise ValueError("Generator code must define a `generate()` function") + result = allowed_globals["generate"]() + if not isinstance(result, dict): + raise ValueError(f"generate() must return a dict, got {type(result).__name__}") + return result + diff --git a/backend/tooling/serialization.py b/backend/tooling/serialization.py new file mode 100644 index 0000000..ecff326 --- /dev/null +++ b/backend/tooling/serialization.py @@ -0,0 +1,96 @@ +"""Serialization helpers for tool outputs.""" + +from typing import Any, Dict, List + + +def json_safe(obj: Any) -> Any: + try: + import numpy as np + except ImportError: + np = None + + try: + import pandas as pd + except ImportError: + pd = None + + if obj is None or isinstance(obj, (str, int, float, bool)): + return obj + if np is not None: + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.bool_): + return bool(obj) + if pd is not None: + if isinstance(obj, pd.DataFrame): + return obj.to_dict(orient="records") + if isinstance(obj, pd.Series): + return obj.to_list() + if isinstance(obj, dict): + return {str(k): json_safe(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple, set)): + return [json_safe(v) for v in obj] + if hasattr(obj, "model_dump") and callable(obj.model_dump): + return json_safe(obj.model_dump()) + if hasattr(obj, "dict") and callable(obj.dict): + return json_safe(obj.dict()) + try: + import dataclasses + + if dataclasses.is_dataclass(obj): + return json_safe(dataclasses.asdict(obj)) + except Exception: + pass + return str(obj) + + +def dataframe_to_records(df) -> List[Dict[str, Any]]: + return [ + { + key: ( + None + if (hasattr(value, "__class__") and value.__class__.__name__ == "float" and str(value) == "nan") + else value + ) + for key, value in row.items() + } + for row in df.to_dict(orient="records") + ] + + +def explore_tabular_data(data: List[Dict[str, Any]], max_unique_values: int = 20) -> Dict[str, Any]: + if not data or not isinstance(data[0], dict): + return {"error": "Data must be a non-empty list of dicts", "row_count": 0, "columns": []} + row_count = len(data) + all_keys = set() + for row in data: + all_keys.update(row.keys()) + columns = [] + for key in sorted(all_keys): + values = [row.get(key) for row in data] + sample_type = next((type(v).__name__ for v in values if v is not None), "unknown") + unique_values = list(set(v for v in values if v is not None)) + unique_count = len(unique_values) + col_info = { + "name": key, + "type": sample_type, + "unique_count": unique_count, + "null_count": sum(1 for v in values if v is None), + } + if unique_count <= max_unique_values: + try: + col_info["unique_values"] = sorted(unique_values) + except TypeError: + col_info["unique_values"] = unique_values + if sample_type in ("int", "float"): + numeric = [v for v in values if isinstance(v, (int, float))] + if numeric: + col_info["min"] = min(numeric) + col_info["max"] = max(numeric) + columns.append(col_info) + return {"row_count": row_count, "columns": columns} + diff --git a/backend/tooling/simulations.py b/backend/tooling/simulations.py new file mode 100644 index 0000000..813302f --- /dev/null +++ b/backend/tooling/simulations.py @@ -0,0 +1,60 @@ +"""PolicyEngine UK compiled-package and simulation helpers.""" + +from pathlib import Path +import sys +from typing import Any, Dict + + +DATASET_LABELS = { + "frs": "Family Resources Survey", + "efrs": "Enhanced FRS", + "spi": "Survey of Personal Incomes", + "lcfs": "Living Costs and Food Survey", + "was": "Wealth and Assets Survey", +} + + +def ensure_compiled_package_importable() -> None: + """Make the local policyengine_uk_compiled package importable in dev setups.""" + try: + import policyengine_uk_compiled # noqa: F401 + return + except ModuleNotFoundError: + pass + + repo_parent = Path(__file__).resolve().parents[3] + candidates = [ + repo_parent / "policyengine-uk-rust" / "interfaces" / "python", + repo_parent / "policyengine-uk-rust-codex-debug-issue" / "interfaces" / "python", + ] + for candidate in candidates: + if candidate.is_dir(): + candidate_str = str(candidate) + if candidate_str not in sys.path: + sys.path.insert(0, candidate_str) + try: + import policyengine_uk_compiled # noqa: F401 + return + except ModuleNotFoundError: + continue + + raise ModuleNotFoundError( + "policyengine_uk_compiled is not importable. Install the package or make sure a local " + "policyengine-uk-rust checkout with interfaces/python is available." + ) + + +def build_simulation(year: int, dataset: str = "frs"): + """Build a compiled PolicyEngine UK Simulation.""" + ensure_compiled_package_importable() + from policyengine_uk_compiled import Simulation + + return Simulation(year=year, dataset=dataset) + + +def get_capabilities() -> Dict[str, Any]: + ensure_compiled_package_importable() + from policyengine_uk_compiled import capabilities + + return capabilities() + diff --git a/docs/engineering/skills/uk-chat-runtime.md b/docs/engineering/skills/uk-chat-runtime.md index ceeb9a9..6381c7a 100644 --- a/docs/engineering/skills/uk-chat-runtime.md +++ b/docs/engineering/skills/uk-chat-runtime.md @@ -10,8 +10,12 @@ tools, calculation behavior, or AI-facing runtime boundaries. - `backend/routes/chatbot.py` owns application orchestration: request parsing, system block assembly, model calls, SSE streaming, tool-loop handling, usage/billing, title generation, and follow-up suggestions. -- `backend/agent_tools.py` owns deterministic tool implementations and model - tool schemas. +- `backend/agent_tools.py` owns the model-facing tool functions, dispatcher, + and compatibility exports. +- `backend/tool_definitions.py` owns model-facing tool schemas and + descriptions. Reuse shared schema fragments there rather than duplicating + object/array/dataset/format shapes. +- Shared deterministic tool helpers live under `backend/tooling/`. - `backend/scripts/build_reference.py` builds the API reference that is attached to the chat system prompt. @@ -34,33 +38,53 @@ belong in `backend/prompts.py`. Only tools listed in `TOOL_DEFINITIONS` and dispatched by `execute_tool()` are exposed to the model. At present, the exposed tools are: -- `run_python`: execute reproducible PolicyEngine UK Python code. +- `calculate_household`: calculate illustrative synthetic household outcomes. +- `validate_reform`: validate parametric reform JSON without running a + simulation. +- `run_economy_simulation`: calculate aggregate society-wide impacts for + parametric reforms. +- `analyse_microdata`: analyse allowed non-FRS model microdata through bounded + filtering, sampling, grouping, and aggregation operations. +- `run_python`: execute reproducible PolicyEngine UK Python code for fallback + cases that do not fit the typed tools. - `generate_chart`: return frontend-renderable chart JSON markdown. -Helper functions in `backend/agent_tools.py` are implementation details unless -they are added to both the tool definitions and dispatcher. +Helper functions in `backend/tooling/` are implementation details unless they +are added to both the tool definitions and dispatcher. ## Deterministic And Non-Deterministic Segments - Non-deterministic: user text interpretation, model planning, tool selection, prose generation, follow-up suggestions, and title generation. - Deterministic: request validation, plan-mode tool omission, tool dispatch, - Python execution, chart JSON construction, result truncation/summarisation, - billing calculation, and database writes. + typed tool execution after selection, Python execution, chart JSON + construction, result truncation/summarisation, billing calculation, and + database writes. Plan mode must remain structurally enforced by omitting tools from the model request, not only by prompting the model not to call tools. +Tool choice is model-mediated unless the route layer deliberately forces a +specific tool. Prompt and schema guidance improve selection consistency, but +they are not deterministic controls. The chat route defaults the model +temperature to `0` to reduce sampling variance. + ## Policy Analysis Rules - Be factually neutral. Do not call UK tax or benefit choices good, bad, fair, unfair, regressive, progressive, generous, punitive, or similar. -- Quantitative policy answers should be computed with `run_python`; do not +- Quantitative policy answers should be computed with the typed calculation + tools when they fit the request, or with `run_python` as a fallback; do not answer tax, benefit, reform, poverty, decile, or distributional questions from memory. +- Use `validate_reform` only when the user is drafting, debugging, or asking + whether reform JSON is valid. Do not use it as a routine preflight before + every simulation. - Do not access, display, quote, or imply access to row-level survey microdata or real households. - Use aggregate microdata interfaces only for aggregate outputs. +- Do not use `analyse_microdata` with FRS. For FRS-backed questions, use + aggregate outputs such as `run_economy_simulation`. - If a user asks for household examples, construct illustrative synthetic households with the public `Simulation` API. Prefer `Simulation.single_person()` for single-person examples, and label examples as diff --git a/frontend/src/app/ChatPage.tsx b/frontend/src/app/ChatPage.tsx index d3d8cb8..17cca06 100644 --- a/frontend/src/app/ChatPage.tsx +++ b/frontend/src/app/ChatPage.tsx @@ -1298,7 +1298,13 @@ export default function ChatPage() { > {t.status === "pending" && } {hasDetails && } - {t.tool_name === "run_python" ? "python" : t.tool_name} + {({ + run_python: "python", + calculate_household: "household sim", + validate_reform: "reform validation", + run_economy_simulation: "economy sim", + analyse_microdata: "microdata analysis", + } as Record)[t.tool_name] ?? t.tool_name} {t.status !== "pending" && } {isExpanded && hasDetails && renderToolDetails(t)}