diff --git a/scripts/llm_transport.py b/scripts/llm_transport.py index 0b775808ea..a554de44ba 100644 --- a/scripts/llm_transport.py +++ b/scripts/llm_transport.py @@ -44,7 +44,7 @@ # ── Anthropic API ──────────────────────────────────────────────────── ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" -ANTHROPIC_DEFAULT_MODEL = "claude-opus-4-20250514" +ANTHROPIC_DEFAULT_MODEL = "claude-opus-4-6" ANTHROPIC_API_VERSION = "2023-06-01" diff --git a/scripts/primer_classifier/classifier.py b/scripts/primer_classifier/classifier.py index 4b9102c645..a7c2b3281e 100644 --- a/scripts/primer_classifier/classifier.py +++ b/scripts/primer_classifier/classifier.py @@ -24,6 +24,7 @@ assign_verdict_with_llm, CategoryVerdict, classify_with_llm, + critique_reasoning, generate_suggestions, LLMError, ) @@ -696,13 +697,31 @@ def classify_project( base.method = "llm" return base - # Pass 2: assign verdict based on reasoning + # Pass 1.5: self-critique the reasoning for factual errors + try: + critiqued_reason, critiqued_categories = critique_reasoning( + llm_result.reason, + llm_result.categories, + errors_text, + source_context, + model, + ) + except LLMError as e: + print( + f" Warning: self-critique failed for {project.name}: {e}, " + "using original reasoning", + file=sys.stderr, + ) + critiqued_reason = llm_result.reason + critiqued_categories = llm_result.categories + + # Pass 2: assign verdict based on (critiqued) reasoning verdict, categories_with_verdicts = assign_verdict_with_llm( - llm_result.reason, llm_result.categories, model + critiqued_reason, critiqued_categories, model ) base.verdict = verdict - base.reason = llm_result.reason + base.reason = critiqued_reason base.method = "llm" base.categories = categories_with_verdicts base.pr_attribution = llm_result.pr_attribution @@ -722,6 +741,74 @@ def classify_project( return base +def _enforce_cross_project_consistency( + classifications: list[Classification], +) -> None: + """Enforce verdict consistency across 
projects that share error kinds. + + When multiple LLM-classified projects share the same error kind(s) and + have conflicting verdicts, the majority verdict wins. This prevents the + classifier from saying "overload resolution improved" for one project + and "overload resolution regressed" for another with the same pattern. + + Modifies classifications in place. + """ + # Only consider LLM-classified projects with clear verdicts + llm_classified = [ + c for c in classifications + if c.method == "llm" and c.verdict in ("regression", "improvement") + ] + if len(llm_classified) < 2: + return + + # Group projects by their error kinds (keyed by the error-kind string) + kind_to_projects: dict[str, list[Classification]] = defaultdict(list) + for c in llm_classified: + for kind in c.error_kinds: + kind_to_projects[kind].append(c) + + # For each error kind shared by multiple projects, check consistency + already_adjusted: set[str] = set() + for kind, group in kind_to_projects.items(): + if len(group) < 2: + continue + + verdicts = [c.verdict for c in group] + if len(set(verdicts)) <= 1: + continue # already consistent + + # Count verdicts + verdict_counts: dict[str, int] = {} + for v in verdicts: + verdict_counts[v] = verdict_counts.get(v, 0) + 1 + + majority = max(verdict_counts, key=lambda v: verdict_counts[v]) + minority_count = sum( + c for v, c in verdict_counts.items() if v != majority + ) + + # Only enforce if majority is clear (> minority) + if verdict_counts[majority] <= minority_count: + continue + + # Update minority projects to match majority + adjusted_names = [] + for c in group: + if c.verdict != majority and c.project_name not in already_adjusted: + + c.verdict = majority + adjusted_names.append(c.project_name) + already_adjusted.add(c.project_name) + + if adjusted_names: + print( + f" Cross-project consistency [{kind}]: " + f"{', '.join(adjusted_names)} adjusted to {majority} " + f"(vote: {verdict_counts})", + file=sys.stderr, + ) + + +def 
classify_all( projects: list[ProjectDiff], fetch_code: bool = True, @@ -766,6 +853,11 @@ def classify_all( ) result.classifications.append(classification) + # Enforce cross-project consistency before counting verdicts + _enforce_cross_project_consistency(result.classifications) + + # Count verdicts after consistency enforcement + for classification in result.classifications: if classification.verdict == "regression": result.regressions += 1 elif classification.verdict == "improvement": diff --git a/scripts/primer_classifier/formatter.py b/scripts/primer_classifier/formatter.py index 092c17e97b..85def0fbe5 100644 --- a/scripts/primer_classifier/formatter.py +++ b/scripts/primer_classifier/formatter.py @@ -111,6 +111,42 @@ def func_replacer(match: re.Match) -> str: return result +def _format_reason(reason: str) -> str: + """Format a reason string for display, handling raw JSON dicts. + + When the LLM returns a JSON dict as the reason (with fields like + spec_check, runtime_behavior, etc.), format it into readable text + instead of dumping raw JSON. 
+ """ + if not reason or not reason.strip().startswith("{"): + return reason + try: + parsed = json.loads(reason) + if not isinstance(parsed, dict): + return reason + except (json.JSONDecodeError, ValueError): + return reason + + # Format known analysis fields into readable text + _FIELD_LABELS = { + "spec_check": "Spec check", + "runtime_behavior": "Runtime behavior", + "mypy_pyright": "Mypy/pyright comparison", + "removal_assessment": "Removal assessment", + "pr_attribution": "PR attribution", + "reason": "Reasoning", + } + parts = [] + for key, label in _FIELD_LABELS.items(): + val = parsed.get(key) + if val and val != "N/A": + parts.append(f"**{label}:** {val}") + # Fall back to the "reason" field if nothing else was formatted + if not parts: + return parsed.get("reason", reason) + return "\n> ".join(parts) + + def _extract_root_cause(c) -> str: """Extract a linkified root cause string from a classification's pr_attribution. @@ -292,7 +328,7 @@ def format_markdown(result: ClassificationResult) -> str: ) lines.append("") else: - lines.append(f"> {c.reason}") + lines.append(f"> {_format_reason(c.reason)}") if c.pr_attribution and c.pr_attribution != "N/A": lines.append( f"> **Attribution:** " diff --git a/scripts/primer_classifier/llm_client.py b/scripts/primer_classifier/llm_client.py index ee1e26c721..be08e89726 100644 --- a/scripts/primer_classifier/llm_client.py +++ b/scripts/primer_classifier/llm_client.py @@ -335,7 +335,10 @@ def classify_with_llm( raw_response=result, ) - reason = classification.get("reason", "No reason provided") + reason_val = classification.get("reason", "No reason provided") + # The LLM sometimes returns a dict for "reason" instead of a string. + # If so, serialize it so downstream code always sees a string. 
+ reason = json.dumps(reason_val) if isinstance(reason_val, dict) else str(reason_val) # Parse per-category reasoning (no verdicts in pass 1) categories: list[CategoryVerdict] = [] @@ -360,6 +363,113 @@ def classify_with_llm( ) +def _build_critique_system_prompt() -> str: + """Build the system prompt for pass 1.5: self-critique of reasoning.""" + return """You are reviewing reasoning about pyrefly type checker changes for factual accuracy. Pyrefly is a Python type checker. + +You will receive the original errors, source code context, and a prior analysis. Your job is to check the analysis for factual errors and correct them. Focus on: + +1. **Type system facts**: Verify all claims about what Python types support. Check whether types actually implement the protocols/methods the analysis claims they do or don't. Use your knowledge of the Python type system and standard library. +2. **Inheritance and class hierarchy**: Verify inheritance claims against the source code. Check whether attributes, methods, or protocols are inherited from parent classes that the analysis may have missed. +3. **Code behavior**: Verify that the analysis correctly describes what the source code does. Check variable types, return values, control flow, and assignments against the actual code provided. +4. **Generic type reasoning**: When the analysis reasons about TypeVars, generics, or parameterized types, verify that ALL constraints or bounds are checked, not just a subset. +5. **Logical consistency**: Check that the reasoning does not contradict itself. The conclusion must follow from the evidence presented. + +If you find factual errors, provide corrected reasoning. If the reasoning is factually correct, return it unchanged. 
+ +Respond with JSON only: +{"corrected": true/false, "corrections": "description of what was wrong (empty string if nothing was wrong)", "reason": "the corrected reasoning (or original if no corrections needed)", "categories": [{"category": "short label", "reason": "corrected reasoning"}, ...]} + +The "categories" field is optional — omit it if there are no categories. When present, each entry should match a category from the original reasoning.""" + + +def _build_critique_prompt( + reason: str, + categories: list[CategoryVerdict], + errors_text: str, + source_context: Optional[str], +) -> str: + """Build the user prompt for pass 1.5: the reasoning to critique.""" + parts = [f"Original analysis to review:\n{reason}\n"] + if categories: + parts.append("Per-category reasoning:") + for cat in categories: + parts.append(f"- {cat.category}: {cat.reason}") + parts.append("") + parts.append(f"Original errors:\n{errors_text}\n") + if source_context: + parts.append(f"Source code context:\n{source_context}\n") + else: + parts.append("Source code: not available\n") + return "\n".join(parts) + + +def critique_reasoning( + reason: str, + categories: list[CategoryVerdict], + errors_text: str, + source_context: Optional[str] = None, + model: Optional[str] = None, +) -> tuple[str, list[CategoryVerdict]]: + """Pass 1.5: Self-critique the reasoning from Pass 1 for factual errors. + + Catches hallucinations like "dicts are not iterable" or incorrect + inheritance claims. Returns (corrected_reason, corrected_categories). + """ + backend, api_key = _get_backend() + if backend == "none": + raise LLMError( + "No API key found. Set LLAMA_API_KEY (Meta internal) " + "or CLASSIFIER_API_KEY / ANTHROPIC_API_KEY." 
+ ) + + system_prompt = _build_critique_system_prompt() + user_prompt = _build_critique_prompt( + reason, categories, errors_text, source_context + ) + + print( + f"Using {backend} backend for self-critique (pass 1.5)", + file=sys.stderr, + ) + + if backend == "llama": + result = _call_llama_api(api_key, system_prompt, user_prompt, model) + else: + result = _call_anthropic_api(api_key, system_prompt, user_prompt, model) + + text = _extract_text_from_response(backend, result) + parsed = _parse_classification(text) + + corrected = parsed.get("corrected", False) + corrections = parsed.get("corrections", "") + if corrected and corrections: + print(f" Self-critique found errors: {corrections}", file=sys.stderr) + + corrected_reason = parsed.get("reason", reason) + + # Update category reasoning if corrections were provided + corrected_categories = list(categories) + if corrected: + cat_data_list = parsed.get("categories", []) + if cat_data_list: + cat_by_name = {c.get("category", ""): c for c in cat_data_list} + corrected_categories = [] + for cat in categories: + if cat.category in cat_by_name: + corrected_categories.append( + CategoryVerdict( + category=cat.category, + verdict="", + reason=cat_by_name[cat.category].get("reason", cat.reason), + ) + ) + else: + corrected_categories.append(cat) + + return corrected_reason, corrected_categories + + def _build_verdict_system_prompt() -> str: """Build the system prompt for pass 2: assigning a verdict from reasoning.""" return """You are assigning a verdict based on reasoning about pyrefly type checker changes. Pyrefly is a Python type checker. You are evaluating whether pyrefly got BETTER or WORSE. 
@@ -391,6 +501,25 @@ def _build_verdict_prompt(reason: str, categories: list[CategoryVerdict]) -> str return "\n".join(parts) +_VERDICT_VOTES = 5 # Number of verdict votes for majority voting + + +def _single_verdict_call( + backend: str, + api_key: str, + system_prompt: str, + user_prompt: str, + model: Optional[str], +) -> dict: + """Make a single verdict API call and return the parsed classification.""" + if backend == "llama": + result = _call_llama_api(api_key, system_prompt, user_prompt, model) + else: + result = _call_anthropic_api(api_key, system_prompt, user_prompt, model) + text = _extract_text_from_response(backend, result) + return _parse_classification(text) + + def assign_verdict_with_llm( reason: str, categories: list[CategoryVerdict], @@ -398,9 +527,12 @@ def assign_verdict_with_llm( ) -> tuple[str, list[CategoryVerdict]]: """Pass 2: Assign a verdict based on the reasoning from pass 1. - Makes a small, cheap API call (~500 tokens in, ~100 tokens out) that - reads the reasoning and assigns verdicts. Returns (overall_verdict, - categories_with_verdicts). + Uses majority voting: makes multiple cheap API calls (~100 tokens out + each) and takes the most common verdict. This reduces non-determinism + where the same reasoning could be classified as either "improvement" + or "regression" on different runs. + + Returns (overall_verdict, categories_with_verdicts). 
""" backend, api_key = _get_backend() if backend == "none": @@ -412,32 +544,66 @@ def assign_verdict_with_llm( system_prompt = _build_verdict_system_prompt() user_prompt = _build_verdict_prompt(reason, categories) - print(f"Using {backend} backend for verdict assignment (pass 2)", file=sys.stderr) + print( + f"Using {backend} backend for verdict assignment " + f"(pass 2, {_VERDICT_VOTES} votes)", + file=sys.stderr, + ) - if backend == "llama": - result = _call_llama_api(api_key, system_prompt, user_prompt, model) - else: - result = _call_anthropic_api(api_key, system_prompt, user_prompt, model) + # Collect multiple verdict votes + votes: list[dict] = [] + for i in range(_VERDICT_VOTES): + try: + parsed = _single_verdict_call( + backend, api_key, system_prompt, user_prompt, model + ) + votes.append(parsed) + except LLMError as e: + print( + f" Warning: verdict vote {i + 1}/{_VERDICT_VOTES} failed: {e}", + file=sys.stderr, + ) - text = _extract_text_from_response(backend, result) - parsed = _parse_classification(text) + if not votes: + raise LLMError("All verdict votes failed") + + # Count overall verdict votes + verdict_counts: dict[str, int] = {} + valid_verdicts = ("regression", "improvement", "neutral") + for parsed in votes: + v = parsed.get("verdict", "").lower().strip() + if v in valid_verdicts: + verdict_counts[v] = verdict_counts.get(v, 0) + 1 - verdict = parsed.get("verdict", "").lower().strip() - if verdict not in ("regression", "improvement", "neutral"): + if not verdict_counts: print( - f"Warning: verdict pass returned unexpected verdict '{verdict}', " - "treating as ambiguous", + "Warning: no valid verdicts from any vote, treating as ambiguous", file=sys.stderr, ) verdict = "neutral" - - # Merge per-category verdicts back into the category objects + else: + verdict = max(verdict_counts, key=lambda v: verdict_counts[v]) + + # Log vote distribution for transparency + vote_summary = ", ".join(f"{v}={c}" for v, c in sorted(verdict_counts.items())) + 
print(f" Verdict votes: {vote_summary} → {verdict}", file=sys.stderr) + + # Count per-category verdict votes across all votes + category_vote_counts: dict[str, dict[str, int]] = {} + for parsed in votes: + for cat_data in parsed.get("categories", []): + cat_name = cat_data.get("category", "") + cat_verdict = cat_data.get("verdict", "").lower().strip() + if cat_verdict in valid_verdicts: + if cat_name not in category_vote_counts: + category_vote_counts[cat_name] = {} + counts = category_vote_counts[cat_name] + counts[cat_verdict] = counts.get(cat_verdict, 0) + 1 + + # Pick majority verdict for each category verdict_by_category: dict[str, str] = {} - for cat_data in parsed.get("categories", []): - cat_verdict = cat_data.get("verdict", "").lower().strip() - if cat_verdict not in ("regression", "improvement", "neutral"): - cat_verdict = "neutral" - verdict_by_category[cat_data.get("category", "")] = cat_verdict + for cat_name, counts in category_vote_counts.items(): + verdict_by_category[cat_name] = max(counts, key=lambda v: counts[v]) updated_categories = [] for cat in categories: diff --git a/scripts/primer_classifier/test_classifier.py b/scripts/primer_classifier/test_classifier.py index e980ae4b1f..afa8c6e276 100644 --- a/scripts/primer_classifier/test_classifier.py +++ b/scripts/primer_classifier/test_classifier.py @@ -1000,7 +1000,7 @@ def test_pass1_returns_empty_verdict(self): assert result.categories[0].verdict == "" def test_assign_verdict_improvement(self): - """assign_verdict_with_llm should assign 'improvement' for false-positive reasoning.""" + """assign_verdict_with_llm should assign 'improvement' via majority vote.""" verdict_response = { "verdict": "improvement", "categories": [{"category": "missing-attr", "verdict": "improvement"}], @@ -1020,7 +1020,7 @@ def test_assign_verdict_improvement(self): assert updated_cats[0].reason == "false positives" def test_assign_verdict_regression(self): - """assign_verdict_with_llm should assign 'regression' for 
real-bug reasoning.""" + """assign_verdict_with_llm should assign 'regression' via majority vote.""" verdict_response = {"verdict": "regression"} with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}, clear=True): with patch( @@ -1034,7 +1034,7 @@ def test_assign_verdict_regression(self): assert verdict == "regression" def test_two_pass_end_to_end(self): - """Full two-pass flow: classify_project calls pass 1 then pass 2.""" + """Full multi-pass flow: classify_project calls pass 1, 1.5, then pass 2.""" pass1_response = { "reason": "These missing-attribute errors are false positives", "pr_attribution": "Change in solver.rs", @@ -1043,6 +1043,11 @@ def test_two_pass_end_to_end(self): "mypy_pyright": "N/A", "removal_assessment": "False positives", } + critique_response = { + "corrected": False, + "corrections": "", + "reason": "These missing-attribute errors are false positives", + } pass2_response = {"verdict": "improvement"} p = ProjectDiff(name="test", removed=[self._make_entry()]) @@ -1051,22 +1056,28 @@ def test_two_pass_end_to_end(self): with patch( "primer_classifier.llm_client._call_anthropic_api", ) as mock_api: + from .llm_client import _VERDICT_VOTES + mock_api.side_effect = [ # Pass 1: reasoning {"content": [{"text": json.dumps(pass1_response)}]}, - # Pass 2: verdict - {"content": [{"text": json.dumps(pass2_response)}]}, + # Pass 1.5: self-critique + {"content": [{"text": json.dumps(critique_response)}]}, + # Pass 2: verdict (N votes) + ] + [ + {"content": [{"text": json.dumps(pass2_response)}]} + for _ in range(_VERDICT_VOTES) ] result = classify_project(p, fetch_code=False, use_llm=True) assert result.verdict == "improvement" assert result.reason == "These missing-attribute errors are false positives" assert result.pr_attribution == "Change in solver.rs" assert result.method == "llm" - # Verify both passes were called - assert mock_api.call_count == 2 + # Pass 1 + Pass 1.5 + Pass 2 (N votes) + assert mock_api.call_count == 2 + _VERDICT_VOTES def 
test_two_pass_with_categories(self): - """Two-pass flow with per-category verdicts.""" + """Multi-pass flow with per-category verdicts.""" pass1_response = { "reason": "Mixed results", "pr_attribution": "N/A", @@ -1075,6 +1086,15 @@ def test_two_pass_with_categories(self): {"category": "bad-return", "reason": "real type errors caught"}, ], } + critique_response = { + "corrected": False, + "corrections": "", + "reason": "Mixed results", + "categories": [ + {"category": "missing-attr", "reason": "false positives from inheritance"}, + {"category": "bad-return", "reason": "real type errors caught"}, + ], + } pass2_response = { "verdict": "regression", "categories": [ @@ -1094,9 +1114,14 @@ def test_two_pass_with_categories(self): with patch( "primer_classifier.llm_client._call_anthropic_api", ) as mock_api: + from .llm_client import _VERDICT_VOTES + mock_api.side_effect = [ {"content": [{"text": json.dumps(pass1_response)}]}, - {"content": [{"text": json.dumps(pass2_response)}]}, + {"content": [{"text": json.dumps(critique_response)}]}, + ] + [ + {"content": [{"text": json.dumps(pass2_response)}]} + for _ in range(_VERDICT_VOTES) ] result = classify_project(p, fetch_code=False, use_llm=True) assert result.verdict == "regression" @@ -1112,6 +1137,11 @@ def test_two_pass_with_file_request(self): "reason": "After seeing source: false positives", "pr_attribution": "N/A", } + critique_response = { + "corrected": False, + "corrections": "", + "reason": "After seeing source: false positives", + } pass2_response = {"verdict": "improvement"} p = ProjectDiff( @@ -1124,13 +1154,19 @@ def test_two_pass_with_file_request(self): with patch( "primer_classifier.llm_client._call_anthropic_api", ) as mock_api: + from .llm_client import _VERDICT_VOTES + mock_api.side_effect = [ # Pass 1, attempt 1: needs files {"content": [{"text": json.dumps(needs_files_response)}]}, # Pass 1, attempt 2 (with files): reasoning {"content": [{"text": json.dumps(pass1_response)}]}, - # Pass 2: verdict - 
{"content": [{"text": json.dumps(pass2_response)}]}, + # Pass 1.5: self-critique + {"content": [{"text": json.dumps(critique_response)}]}, + # Pass 2: verdict (N votes) + ] + [ + {"content": [{"text": json.dumps(pass2_response)}]} + for _ in range(_VERDICT_VOTES) ] with patch( "primer_classifier.classifier.fetch_files_by_path", @@ -1138,7 +1174,8 @@ def test_two_pass_with_file_request(self): ): result = classify_project(p, fetch_code=True, use_llm=True) assert result.verdict == "improvement" - assert mock_api.call_count == 3 + # Pass 1 (2 attempts) + Pass 1.5 + Pass 2 (N votes) + assert mock_api.call_count == 3 + _VERDICT_VOTES # --------------------------------------------------------------------------- @@ -1228,6 +1265,11 @@ def test_full_pipeline(self): "reason": "Variance check too broad", "pr_attribution": "Removed is_protocol() guard", } + critique_response = { + "corrected": False, + "corrections": "", + "reason": "Variance check too broad", + } pass2_response = {"verdict": "regression"} pass3_response = { "summary": "Restore protocol guard", @@ -1248,9 +1290,15 @@ def test_full_pipeline(self): with patch( "primer_classifier.llm_client._call_anthropic_api", ) as mock_api: + from .llm_client import _VERDICT_VOTES + mock_api.side_effect = [ {"content": [{"text": json.dumps(pass1_response)}]}, # Pass 1 - {"content": [{"text": json.dumps(pass2_response)}]}, # Pass 2 + {"content": [{"text": json.dumps(critique_response)}]}, # Pass 1.5 + ] + [ + {"content": [{"text": json.dumps(pass2_response)}]} # Pass 2 + for _ in range(_VERDICT_VOTES) + ] + [ {"content": [{"text": json.dumps(pass3_response)}]}, # Pass 3 ] result = classify_all( @@ -1263,7 +1311,8 @@ def test_full_pipeline(self): assert result.suggestion is not None assert len(result.suggestion.suggestions) == 1 assert result.suggestion.suggestions[0].description == "Add is_protocol() check" - assert mock_api.call_count == 3 + # Pass 1 + Pass 1.5 + Pass 2 (N votes) + Pass 3 + assert mock_api.call_count == 2 + 
_VERDICT_VOTES + 1 class TestSuggestionInMarkdownOutput: