diff --git a/app/src/app/model/[id]/page.tsx b/app/src/app/model/[id]/page.tsx new file mode 100644 index 0000000..ec2ebd4 --- /dev/null +++ b/app/src/app/model/[id]/page.tsx @@ -0,0 +1,419 @@ +import { notFound } from "next/navigation"; +import type { Metadata } from "next"; +import Link from "next/link"; + +import SiteHeader from "../../../components/SiteHeader"; +import ProviderMark from "../../../components/ProviderMark"; +import { + MODEL_LABELS, + PROVIDER_LABELS, + getProviderForModel, +} from "../../../modelMeta"; +import { buildAllRows } from "../../../lib/sensitivity"; +import { metricTypeForVariable } from "../../../lib/scoring"; +import { + getVariableLabel, + VIEW_SHORT_LABELS, + type CountryCode, + type DashboardBundle, +} from "../../../types"; + +import rawDashboard from "../../../data.json"; + +const dashboard = rawDashboard as unknown as DashboardBundle; + +function getAllModelIds(): string[] { + const models = new Set(); + if (dashboard.global) { + for (const s of dashboard.global.modelStats) models.add(s.model); + } + for (const c of ["us", "uk"] as CountryCode[]) { + const d = dashboard.countries[c]; + if (d) for (const s of d.modelStats) models.add(s.model); + } + return [...models]; +} + +export function generateStaticParams() { + return getAllModelIds().map((id) => ({ id })); +} + +export async function generateMetadata({ + params, +}: { + params: Promise<{ id: string }>; +}): Promise { + const { id } = await params; + const label = MODEL_LABELS[id] ?? id; + return { + title: label, + description: `PolicyBench per-model analysis for ${label}: variable scores, hardest outputs, and sample wrong predictions.`, + }; +} + +function Badge({ score }: { score: number }) { + let cls = ""; + if (score >= 80) cls = "text-success-text bg-success-soft border-success/30"; + else if (score >= 65) + cls = "text-primary-strong bg-primary-soft border-primary/30"; + else if (score >= 50) + cls = "text-warning-text bg-warning-soft border-warning/40"; + else cls = "text-danger-text bg-danger-soft border-danger/40"; + return ( + + {score.toFixed(1)}% + + ); +} + +export default async function ModelPage({ + params, +}: { + params: Promise<{ id: string }>; +}) { + const { id: modelId } = await params; + + if (!getAllModelIds().includes(modelId)) notFound(); + + const modelLabel = MODEL_LABELS[modelId] ?? modelId; + const provider = getProviderForModel(modelId); + + // Global and country-level stats (no_tools condition) + const globalStat = dashboard.global?.modelStats.find( + (s) => s.model === modelId && s.condition === "no_tools", + ); + const usStat = dashboard.countries.us?.modelStats.find( + (s) => s.model === modelId && s.condition === "no_tools", + ); + const ukStat = dashboard.countries.uk?.modelStats.find( + (s) => s.model === modelId && s.condition === "no_tools", + ); + + // Global score and country scores — prefer countryScores on globalStat, + // fall back to per-country modelStat. + const globalScore = globalStat?.score; + const usScore = + (globalStat?.countryScores?.us ?? usStat?.score); + const ukScore = + (globalStat?.countryScores?.uk ?? ukStat?.score); + + // Parse coverage from global stat (nParsed / n), fall back to sum of countries. + let parseCov: number | null = null; + if (globalStat && globalStat.n > 0) { + parseCov = (globalStat.nParsed / globalStat.n) * 100; + } else { + const totalN = (usStat?.n ?? 0) + (ukStat?.n ?? 0); + const totalParsed = (usStat?.nParsed ?? 0) + (ukStat?.nParsed ?? 0); + if (totalN > 0) parseCov = (totalParsed / totalN) * 100; + } + + // --- Hardest output groups: (country, outputGroup) → mean score (0–100) --- + const allRows = buildAllRows(dashboard).filter((r) => r.model === modelId); + + const ogMap = new Map(); + for (const row of allRows) { + const key = `${row.country}|${row.outputGroup}`; + const c = ogMap.get(key) ?? { sum: 0, count: 0 }; + c.sum += row.score * 100; + c.count += 1; + ogMap.set(key, c); + } + const hardestVars = [...ogMap.entries()] + .map(([key, { sum, count }]) => { + const [country, outputGroup] = key.split("|") as [CountryCode, string]; + return { country, outputGroup, score: count > 0 ? sum / count : 0 }; + }) + .sort((a, b) => a.score - b.score) + .slice(0, 5); + + // --- Sample wrong predictions: relErr > 10%, score < 0.75 --- + type WrongPred = { + country: CountryCode; + scenarioId: string; + variable: string; + truth: number; + prediction: number; + score: number; + explanation?: string; + }; + const wrong: WrongPred[] = []; + const seen = new Set(); + + for (const country of ["us", "uk"] as CountryCode[]) { + const payload = dashboard.countries[country]; + if (!payload) continue; + for (const [scenarioId, varMap] of Object.entries( + payload.scenarioPredictions, + )) { + for (const [variable, modelMap] of Object.entries(varMap)) { + const rec = (modelMap as Record)[modelId]; + if (!rec || rec.prediction == null) continue; + const truth = rec.groundTruth; + const pred = rec.prediction as number; + const relErr = + truth !== 0 + ? Math.abs((pred - truth) / truth) + : Math.abs(pred) > 1 + ? 1 + : 0; + if (relErr <= 0.1) continue; + const rowScore = (rec.score ?? 0) as number; + if (rowScore >= 0.75) continue; + const key = `${country}|${scenarioId}|${variable}`; + if (seen.has(key)) continue; + seen.add(key); + wrong.push({ + country, + scenarioId, + variable, + truth, + prediction: pred, + score: rowScore, + explanation: rec.explanation as string | undefined, + }); + } + } + } + wrong.sort((a, b) => { + const ea = + a.truth !== 0 ? Math.abs((a.prediction - a.truth) / a.truth) : 1; + const eb = + b.truth !== 0 ? Math.abs((b.prediction - b.truth) / b.truth) : 1; + return eb - ea; + }); + const samples = wrong.slice(0, 10); + + // ---- Header expanded content ---- + const expanded = ( +
+
+ + + {modelLabel} + + {provider && ( + + {PROVIDER_LABELS[provider]} + + )} +
+
+ {globalScore !== undefined && ( +
+ + Global + + + {globalScore.toFixed(1)}% + +
+ )} + {usScore !== undefined && ( +
+ + US + + + {usScore.toFixed(1)}% + +
+ )} + {ukScore !== undefined && ( +
+ + UK + + + {ukScore.toFixed(1)}% + +
+ )} + {parseCov !== null && ( +
+ + Parse rate + + + {parseCov.toFixed(1)}% + +
+ )} +
+
+ ); + + return ( +
+

{modelLabel} — PolicyBench model deep-dive

+ + +
+ {/* Hardest output groups */} +
+
Hardest outputs
+

+ Top 5 lowest-scoring outputs +

+ {hardestVars.length === 0 ? ( +

+ No scored rows found for this model. +

+ ) : ( +
+ {hardestVars.map(({ country, outputGroup, score }, i) => ( +
+
+
+ + {i + 1} + +
+
+ {getVariableLabel(outputGroup, country)} +
+
+ {VIEW_SHORT_LABELS[country]} +
+
+
+ +
+
+ ))} +
+ )} +
+ + {/* Sample wrong predictions */} + {samples.length > 0 && ( +
+
Wrong predictions
+

+ Sample errors (>10% off) +

+

+ Cases where this model's prediction differed from the + PolicyEngine reference by more than 10%, sorted by largest + relative error. +

+
+ {samples.map( + ({ + country, + scenarioId, + variable, + truth, + prediction, + score, + explanation, + }) => { + const relErrPct = + truth !== 0 + ? Math.abs((prediction - truth) / truth) * 100 + : null; + const metricType = metricTypeForVariable(variable, country); + const fmt = (v: number) => + metricType === "amount" + ? `$${Math.round(v).toLocaleString()}` + : String(Math.round(v)); + return ( +
+
+ + {VIEW_SHORT_LABELS[country]} + + + {getVariableLabel(variable, country)} + + + + +
+
+
+
+ Prediction +
+
+ {fmt(prediction)} +
+
+
+
+ Ground truth +
+
+ {fmt(truth)} +
+
+ {relErrPct !== null && ( +
+
+ Error +
+
+ {relErrPct.toFixed(1)}% +
+
+ )} +
+ {explanation && ( +
+ + Model explanation + +

+ {explanation} +

+
+ )} +
+ + View in scenario explorer → + + · + + {scenarioId} + +
+
+ ); + }, + )} +
+
+ )} + +
+ + ← Back to leaderboard + +
+
+
+ ); +}