From c443c5f265b2d4a51bd72a4d8e9e5ba54a9dad2a Mon Sep 17 00:00:00 2001 From: Zeina Migeed Date: Fri, 20 Mar 2026 18:07:48 -0700 Subject: [PATCH 1/2] Fix raw JSON rendering in classifier verdict display Summary: I noticed when looking at the classifier output for https://github.com/facebook/pyrefly/pull/2764 that the "verdict" formatting needed to be fixed. Two fixes: 1. formatter.py: Add _format_reason() to render JSON reason dicts as labeled readable sections (e.g. "**Spec check:** ...", "**Reasoning:** ...") 2. llm_client.py: Ensure reason is always a string by serializing dict values, so downstream code handles it consistently. Reviewed By: grievejia Differential Revision: D97422229 --- scripts/primer_classifier/formatter.py | 38 ++++++++++++++++++++++++- scripts/primer_classifier/llm_client.py | 5 +++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/scripts/primer_classifier/formatter.py b/scripts/primer_classifier/formatter.py index 092c17e97b..85def0fbe5 100644 --- a/scripts/primer_classifier/formatter.py +++ b/scripts/primer_classifier/formatter.py @@ -111,6 +111,42 @@ def func_replacer(match: re.Match) -> str: return result +def _format_reason(reason: str) -> str: + """Format a reason string for display, handling raw JSON dicts. + + When the LLM returns a JSON dict as the reason (with fields like + spec_check, runtime_behavior, etc.), format it into readable text + instead of dumping raw JSON. + """ + if not reason or not reason.strip().startswith("{"): + return reason + try: + parsed = json.loads(reason) + if not isinstance(parsed, dict): + return reason + except (json.JSONDecodeError, ValueError): + return reason + + # Format known analysis fields into readable text + _FIELD_LABELS = { + "spec_check": "Spec check", + "runtime_behavior": "Runtime behavior", + "mypy_pyright": "Mypy/pyright comparison", + "removal_assessment": "Removal assessment", + "pr_attribution": "PR attribution", + "reason": "Reasoning", + } + parts = [] + for key, label in _FIELD_LABELS.items(): + val = parsed.get(key) + if val and val != "N/A": + parts.append(f"**{label}:** {val}") + # Fall back to the "reason" field if nothing else was formatted + if not parts: + return parsed.get("reason", reason) + return "\n> ".join(parts) + + def _extract_root_cause(c) -> str: """Extract a linkified root cause string from a classification's pr_attribution. @@ -292,7 +328,7 @@ def format_markdown(result: ClassificationResult) -> str: ) lines.append("") else: - lines.append(f"> {c.reason}") + lines.append(f"> {_format_reason(c.reason)}") if c.pr_attribution and c.pr_attribution != "N/A": lines.append( f"> **Attribution:** " diff --git a/scripts/primer_classifier/llm_client.py b/scripts/primer_classifier/llm_client.py index ee1e26c721..cd7ebc1240 100644 --- a/scripts/primer_classifier/llm_client.py +++ b/scripts/primer_classifier/llm_client.py @@ -335,7 +335,10 @@ def classify_with_llm( raw_response=result, ) - reason = classification.get("reason", "No reason provided") + reason_val = classification.get("reason", "No reason provided") + # The LLM sometimes returns a dict for "reason" instead of a string. + # If so, serialize it so downstream code always sees a string. + reason = json.dumps(reason_val) if isinstance(reason_val, dict) else str(reason_val) # Parse per-category reasoning (no verdicts in pass 1) categories: list[CategoryVerdict] = [] From 1690591dce59b0ea3da7df18d9f575c4ee17d986 Mon Sep 17 00:00:00 2001 From: Zeina Migeed Date: Fri, 20 Mar 2026 18:07:48 -0700 Subject: [PATCH 2/2] Improve classification quality with self-critique, majority voting, and cross-project consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The primer classifier has been producing inconsistent results across runs — the same primer diff can be classified as 'improvement' in one run and 'regression' in another. This was observed on real PRs like https://github.com/facebook/pyrefly/pull/2839 (altair TypeVar iterability) and https://github.com/facebook/pyrefly/pull/2764 (overload resolution, 60+ projects). Three changes to improve reliability: 1. **Self-critique pass (Pass 1.5)**: After Pass 1 produces reasoning, a new pass checks it for factual errors — e.g., claiming dicts are not iterable, incorrect inheritance claims, wrong TypeVar constraint analysis. This catches hallucinations before they reach the verdict pass. Tested on PR #2839 where it correctly identified that both constraints of `_C` (list and TypedDict) are iterable. 2. **Majority voting on verdict (Pass 2)**: Instead of a single verdict call, makes 5 independent calls and takes the majority. This reduces non-determinism where the same reasoning could be classified either way. Vote distribution is logged for transparency. 3. **Cross-project consistency enforcement**: After classifying all projects independently, groups them by error kind and enforces majority verdict within each group. This prevents the classifier from saying 'overload resolution improved' for one project and 'overload resolution regressed' for another with the same pattern. Also upgrades the default Anthropic model from claude-opus-4-20250514 to claude-opus-4-6 for better Pass 1 reasoning quality. Differential Revision: D97571454 --- scripts/llm_transport.py | 2 +- scripts/primer_classifier/classifier.py | 98 ++++++++- scripts/primer_classifier/llm_client.py | 205 +++++++++++++++++-- scripts/primer_classifier/test_classifier.py | 77 +++++-- 4 files changed, 343 insertions(+), 39 deletions(-) diff --git a/scripts/llm_transport.py b/scripts/llm_transport.py index 0b775808ea..a554de44ba 100644 --- a/scripts/llm_transport.py +++ b/scripts/llm_transport.py @@ -44,7 +44,7 @@ # ── Anthropic API ──────────────────────────────────────────────────── ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" -ANTHROPIC_DEFAULT_MODEL = "claude-opus-4-20250514" +ANTHROPIC_DEFAULT_MODEL = "claude-opus-4-6" ANTHROPIC_API_VERSION = "2023-06-01" diff --git a/scripts/primer_classifier/classifier.py b/scripts/primer_classifier/classifier.py index 4b9102c645..a7c2b3281e 100644 --- a/scripts/primer_classifier/classifier.py +++ b/scripts/primer_classifier/classifier.py @@ -24,6 +24,7 @@ assign_verdict_with_llm, CategoryVerdict, classify_with_llm, + critique_reasoning, generate_suggestions, LLMError, ) @@ -696,13 +697,31 @@ def classify_project( base.method = "llm" return base - # Pass 2: assign verdict based on reasoning + # Pass 1.5: self-critique the reasoning for factual errors + try: + critiqued_reason, critiqued_categories = critique_reasoning( + llm_result.reason, + llm_result.categories, + errors_text, + source_context, + model, + ) + except LLMError as e: + print( + f" Warning: self-critique failed for {project.name}: {e}, " + "using original reasoning", + file=sys.stderr, + ) + critiqued_reason = llm_result.reason + critiqued_categories = llm_result.categories + + # Pass 2: assign verdict based on (critiqued) reasoning verdict, categories_with_verdicts = assign_verdict_with_llm( - llm_result.reason, llm_result.categories, model + critiqued_reason, critiqued_categories, model ) base.verdict = verdict - base.reason = llm_result.reason + base.reason = critiqued_reason base.method = "llm" base.categories = categories_with_verdicts base.pr_attribution = llm_result.pr_attribution @@ -722,6 +741,74 @@ def classify_project( return base +def _enforce_cross_project_consistency( + classifications: list[Classification], +) -> None: + """Enforce verdict consistency across projects that share error kinds. + + When multiple LLM-classified projects share the same error kind(s) and + have conflicting verdicts, the majority verdict wins. This prevents the + classifier from saying "overload resolution improved" for one project + and "overload resolution regressed" for another with the same pattern. + + Modifies classifications in place. + """ + # Only consider LLM-classified projects with clear verdicts + llm_classified = [ + c for c in classifications + if c.method == "llm" and c.verdict in ("regression", "improvement") + ] + if len(llm_classified) < 2: + return + + # Group projects by their error kinds (using frozenset for hashability) + kind_to_projects: dict[str, list[Classification]] = defaultdict(list) + for c in llm_classified: + for kind in c.error_kinds: + kind_to_projects[kind].append(c) + + # For each error kind shared by multiple projects, check consistency + already_adjusted: set[str] = set() + for kind, group in kind_to_projects.items(): + if len(group) < 2: + continue + + verdicts = [c.verdict for c in group] + if len(set(verdicts)) <= 1: + continue # already consistent + + # Count verdicts + verdict_counts: dict[str, int] = {} + for v in verdicts: + verdict_counts[v] = verdict_counts.get(v, 0) + 1 + + majority = max(verdict_counts, key=lambda v: verdict_counts[v]) + minority_count = sum( + c for v, c in verdict_counts.items() if v != majority + ) + + # Only enforce if majority is clear (> minority) + if verdict_counts[majority] <= minority_count: + continue + + # Update minority projects to match majority + adjusted_names = [] + for c in group: + if c.verdict != majority and c.project_name not in already_adjusted: + old = c.verdict + c.verdict = majority + adjusted_names.append(c.project_name) + already_adjusted.add(c.project_name) + + if adjusted_names: + print( + f" Cross-project consistency [{kind}]: " + f"{', '.join(adjusted_names)} adjusted to {majority} " + f"(vote: {verdict_counts})", + file=sys.stderr, + ) + + def classify_all( projects: list[ProjectDiff], fetch_code: bool = True, @@ -766,6 +853,11 @@ def classify_all( ) result.classifications.append(classification) + # Enforce cross-project consistency before counting verdicts + _enforce_cross_project_consistency(result.classifications) + + # Count verdicts after consistency enforcement + for classification in result.classifications: if classification.verdict == "regression": result.regressions += 1 elif classification.verdict == "improvement": diff --git a/scripts/primer_classifier/llm_client.py b/scripts/primer_classifier/llm_client.py index cd7ebc1240..be08e89726 100644 --- a/scripts/primer_classifier/llm_client.py +++ b/scripts/primer_classifier/llm_client.py @@ -363,6 +363,113 @@ def classify_with_llm( ) +def _build_critique_system_prompt() -> str: + """Build the system prompt for pass 1.5: self-critique of reasoning.""" + return """You are reviewing reasoning about pyrefly type checker changes for factual accuracy. Pyrefly is a Python type checker. + +You will receive the original errors, source code context, and a prior analysis. Your job is to check the analysis for factual errors and correct them. Focus on: + +1. **Type system facts**: Verify all claims about what Python types support. Check whether types actually implement the protocols/methods the analysis claims they do or don't. Use your knowledge of the Python type system and standard library. +2. **Inheritance and class hierarchy**: Verify inheritance claims against the source code. Check whether attributes, methods, or protocols are inherited from parent classes that the analysis may have missed. +3. **Code behavior**: Verify that the analysis correctly describes what the source code does. Check variable types, return values, control flow, and assignments against the actual code provided. +4. **Generic type reasoning**: When the analysis reasons about TypeVars, generics, or parameterized types, verify that ALL constraints or bounds are checked, not just a subset. +5. **Logical consistency**: Check that the reasoning does not contradict itself. The conclusion must follow from the evidence presented. + +If you find factual errors, provide corrected reasoning. If the reasoning is factually correct, return it unchanged. + +Respond with JSON only: +{"corrected": true/false, "corrections": "description of what was wrong (empty string if nothing was wrong)", "reason": "the corrected reasoning (or original if no corrections needed)", "categories": [{"category": "short label", "reason": "corrected reasoning"}, ...]} + +The "categories" field is optional — omit it if there are no categories. When present, each entry should match a category from the original reasoning.""" + + +def _build_critique_prompt( + reason: str, + categories: list[CategoryVerdict], + errors_text: str, + source_context: Optional[str], +) -> str: + """Build the user prompt for pass 1.5: the reasoning to critique.""" + parts = [f"Original analysis to review:\n{reason}\n"] + if categories: + parts.append("Per-category reasoning:") + for cat in categories: + parts.append(f"- {cat.category}: {cat.reason}") + parts.append("") + parts.append(f"Original errors:\n{errors_text}\n") + if source_context: + parts.append(f"Source code context:\n{source_context}\n") + else: + parts.append("Source code: not available\n") + return "\n".join(parts) + + +def critique_reasoning( + reason: str, + categories: list[CategoryVerdict], + errors_text: str, + source_context: Optional[str] = None, + model: Optional[str] = None, +) -> tuple[str, list[CategoryVerdict]]: + """Pass 1.5: Self-critique the reasoning from Pass 1 for factual errors. + + Catches hallucinations like "dicts are not iterable" or incorrect + inheritance claims. Returns (corrected_reason, corrected_categories). + """ + backend, api_key = _get_backend() + if backend == "none": + raise LLMError( + "No API key found. Set LLAMA_API_KEY (Meta internal) " + "or CLASSIFIER_API_KEY / ANTHROPIC_API_KEY." + ) + + system_prompt = _build_critique_system_prompt() + user_prompt = _build_critique_prompt( + reason, categories, errors_text, source_context + ) + + print( + f"Using {backend} backend for self-critique (pass 1.5)", + file=sys.stderr, + ) + + if backend == "llama": + result = _call_llama_api(api_key, system_prompt, user_prompt, model) + else: + result = _call_anthropic_api(api_key, system_prompt, user_prompt, model) + + text = _extract_text_from_response(backend, result) + parsed = _parse_classification(text) + + corrected = parsed.get("corrected", False) + corrections = parsed.get("corrections", "") + if corrected and corrections: + print(f" Self-critique found errors: {corrections}", file=sys.stderr) + + corrected_reason = parsed.get("reason", reason) + + # Update category reasoning if corrections were provided + corrected_categories = list(categories) + if corrected: + cat_data_list = parsed.get("categories", []) + if cat_data_list: + cat_by_name = {c.get("category", ""): c for c in cat_data_list} + corrected_categories = [] + for cat in categories: + if cat.category in cat_by_name: + corrected_categories.append( + CategoryVerdict( + category=cat.category, + verdict="", + reason=cat_by_name[cat.category].get("reason", cat.reason), + ) + ) + else: + corrected_categories.append(cat) + + return corrected_reason, corrected_categories + + def _build_verdict_system_prompt() -> str: """Build the system prompt for pass 2: assigning a verdict from reasoning.""" return """You are assigning a verdict based on reasoning about pyrefly type checker changes. Pyrefly is a Python type checker. You are evaluating whether pyrefly got BETTER or WORSE. @@ -394,6 +501,25 @@ def _build_verdict_prompt(reason: str, categories: list[CategoryVerdict]) -> str return "\n".join(parts) +_VERDICT_VOTES = 5 # Number of verdict votes for majority voting + + +def _single_verdict_call( + backend: str, + api_key: str, + system_prompt: str, + user_prompt: str, + model: Optional[str], +) -> dict: + """Make a single verdict API call and return the parsed classification.""" + if backend == "llama": + result = _call_llama_api(api_key, system_prompt, user_prompt, model) + else: + result = _call_anthropic_api(api_key, system_prompt, user_prompt, model) + text = _extract_text_from_response(backend, result) + return _parse_classification(text) + + def assign_verdict_with_llm( reason: str, categories: list[CategoryVerdict], @@ -401,9 +527,12 @@ def assign_verdict_with_llm( ) -> tuple[str, list[CategoryVerdict]]: """Pass 2: Assign a verdict based on the reasoning from pass 1. - Makes a small, cheap API call (~500 tokens in, ~100 tokens out) that - reads the reasoning and assigns verdicts. Returns (overall_verdict, - categories_with_verdicts). + Uses majority voting: makes multiple cheap API calls (~100 tokens out + each) and takes the most common verdict. This reduces non-determinism + where the same reasoning could be classified as either "improvement" + or "regression" on different runs. + + Returns (overall_verdict, categories_with_verdicts). """ backend, api_key = _get_backend() if backend == "none": @@ -415,32 +544,66 @@ def assign_verdict_with_llm( system_prompt = _build_verdict_system_prompt() user_prompt = _build_verdict_prompt(reason, categories) - print(f"Using {backend} backend for verdict assignment (pass 2)", file=sys.stderr) + print( + f"Using {backend} backend for verdict assignment " + f"(pass 2, {_VERDICT_VOTES} votes)", + file=sys.stderr, + ) - if backend == "llama": - result = _call_llama_api(api_key, system_prompt, user_prompt, model) - else: - result = _call_anthropic_api(api_key, system_prompt, user_prompt, model) + # Collect multiple verdict votes + votes: list[dict] = [] + for i in range(_VERDICT_VOTES): + try: + parsed = _single_verdict_call( + backend, api_key, system_prompt, user_prompt, model + ) + votes.append(parsed) + except LLMError as e: + print( + f" Warning: verdict vote {i + 1}/{_VERDICT_VOTES} failed: {e}", + file=sys.stderr, + ) - text = _extract_text_from_response(backend, result) - parsed = _parse_classification(text) + if not votes: + raise LLMError("All verdict votes failed") + + # Count overall verdict votes + verdict_counts: dict[str, int] = {} + valid_verdicts = ("regression", "improvement", "neutral") + for parsed in votes: + v = parsed.get("verdict", "").lower().strip() + if v in valid_verdicts: + verdict_counts[v] = verdict_counts.get(v, 0) + 1 - verdict = parsed.get("verdict", "").lower().strip() - if verdict not in ("regression", "improvement", "neutral"): + if not verdict_counts: print( - f"Warning: verdict pass returned unexpected verdict '{verdict}', " - "treating as ambiguous", + "Warning: no valid verdicts from any vote, treating as ambiguous", file=sys.stderr, ) verdict = "neutral" - - # Merge per-category verdicts back into the category objects + else: + verdict = max(verdict_counts, key=lambda v: verdict_counts[v]) + + # Log vote distribution for transparency + vote_summary = ", ".join(f"{v}={c}" for v, c in sorted(verdict_counts.items())) + print(f" Verdict votes: {vote_summary} → {verdict}", file=sys.stderr) + + # Count per-category verdict votes across all votes + category_vote_counts: dict[str, dict[str, int]] = {} + for parsed in votes: + for cat_data in parsed.get("categories", []): + cat_name = cat_data.get("category", "") + cat_verdict = cat_data.get("verdict", "").lower().strip() + if cat_verdict in valid_verdicts: + if cat_name not in category_vote_counts: + category_vote_counts[cat_name] = {} + counts = category_vote_counts[cat_name] + counts[cat_verdict] = counts.get(cat_verdict, 0) + 1 + + # Pick majority verdict for each category verdict_by_category: dict[str, str] = {} - for cat_data in parsed.get("categories", []): - cat_verdict = cat_data.get("verdict", "").lower().strip() - if cat_verdict not in ("regression", "improvement", "neutral"): - cat_verdict = "neutral" - verdict_by_category[cat_data.get("category", "")] = cat_verdict + for cat_name, counts in category_vote_counts.items(): + verdict_by_category[cat_name] = max(counts, key=lambda v: counts[v]) updated_categories = [] for cat in categories: diff --git a/scripts/primer_classifier/test_classifier.py b/scripts/primer_classifier/test_classifier.py index e980ae4b1f..afa8c6e276 100644 --- a/scripts/primer_classifier/test_classifier.py +++ b/scripts/primer_classifier/test_classifier.py @@ -1000,7 +1000,7 @@ def test_pass1_returns_empty_verdict(self): assert result.categories[0].verdict == "" def test_assign_verdict_improvement(self): - """assign_verdict_with_llm should assign 'improvement' for false-positive reasoning.""" + """assign_verdict_with_llm should assign 'improvement' via majority vote.""" verdict_response = { "verdict": "improvement", "categories": [{"category": "missing-attr", "verdict": "improvement"}], @@ -1020,7 +1020,7 @@ def test_assign_verdict_improvement(self): assert updated_cats[0].reason == "false positives" def test_assign_verdict_regression(self): - """assign_verdict_with_llm should assign 'regression' for real-bug reasoning.""" + """assign_verdict_with_llm should assign 'regression' via majority vote.""" verdict_response = {"verdict": "regression"} with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}, clear=True): with patch( @@ -1034,7 +1034,7 @@ def test_assign_verdict_regression(self): assert verdict == "regression" def test_two_pass_end_to_end(self): - """Full two-pass flow: classify_project calls pass 1 then pass 2.""" + """Full multi-pass flow: classify_project calls pass 1, 1.5, then pass 2.""" pass1_response = { "reason": "These missing-attribute errors are false positives", "pr_attribution": "Change in solver.rs", @@ -1043,6 +1043,11 @@ def test_two_pass_end_to_end(self): "mypy_pyright": "N/A", "removal_assessment": "False positives", } + critique_response = { + "corrected": False, + "corrections": "", + "reason": "These missing-attribute errors are false positives", + } pass2_response = {"verdict": "improvement"} p = ProjectDiff(name="test", removed=[self._make_entry()]) @@ -1051,22 +1056,28 @@ def test_two_pass_end_to_end(self): with patch( "primer_classifier.llm_client._call_anthropic_api", ) as mock_api: + from .llm_client import _VERDICT_VOTES + mock_api.side_effect = [ # Pass 1: reasoning {"content": [{"text": json.dumps(pass1_response)}]}, - # Pass 2: verdict - {"content": [{"text": json.dumps(pass2_response)}]}, + # Pass 1.5: self-critique + {"content": [{"text": json.dumps(critique_response)}]}, + # Pass 2: verdict (N votes) + ] + [ + {"content": [{"text": json.dumps(pass2_response)}]} + for _ in range(_VERDICT_VOTES) ] result = classify_project(p, fetch_code=False, use_llm=True) assert result.verdict == "improvement" assert result.reason == "These missing-attribute errors are false positives" assert result.pr_attribution == "Change in solver.rs" assert result.method == "llm" - # Verify both passes were called - assert mock_api.call_count == 2 + # Pass 1 + Pass 1.5 + Pass 2 (N votes) + assert mock_api.call_count == 2 + _VERDICT_VOTES def test_two_pass_with_categories(self): - """Two-pass flow with per-category verdicts.""" + """Multi-pass flow with per-category verdicts.""" pass1_response = { "reason": "Mixed results", "pr_attribution": "N/A", @@ -1075,6 +1086,15 @@ def test_two_pass_with_categories(self): {"category": "bad-return", "reason": "real type errors caught"}, ], } + critique_response = { + "corrected": False, + "corrections": "", + "reason": "Mixed results", + "categories": [ + {"category": "missing-attr", "reason": "false positives from inheritance"}, + {"category": "bad-return", "reason": "real type errors caught"}, + ], + } pass2_response = { "verdict": "regression", "categories": [ @@ -1094,9 +1114,14 @@ def test_two_pass_with_categories(self): with patch( "primer_classifier.llm_client._call_anthropic_api", ) as mock_api: + from .llm_client import _VERDICT_VOTES + mock_api.side_effect = [ {"content": [{"text": json.dumps(pass1_response)}]}, - {"content": [{"text": json.dumps(pass2_response)}]}, + {"content": [{"text": json.dumps(critique_response)}]}, + ] + [ + {"content": [{"text": json.dumps(pass2_response)}]} + for _ in range(_VERDICT_VOTES) ] result = classify_project(p, fetch_code=False, use_llm=True) assert result.verdict == "regression" @@ -1112,6 +1137,11 @@ def test_two_pass_with_file_request(self): "reason": "After seeing source: false positives", "pr_attribution": "N/A", } + critique_response = { + "corrected": False, + "corrections": "", + "reason": "After seeing source: false positives", + } pass2_response = {"verdict": "improvement"} p = ProjectDiff( @@ -1124,13 +1154,19 @@ def test_two_pass_with_file_request(self): with patch( "primer_classifier.llm_client._call_anthropic_api", ) as mock_api: + from .llm_client import _VERDICT_VOTES + mock_api.side_effect = [ # Pass 1, attempt 1: needs files {"content": [{"text": json.dumps(needs_files_response)}]}, # Pass 1, attempt 2 (with files): reasoning {"content": [{"text": json.dumps(pass1_response)}]}, - # Pass 2: verdict - {"content": [{"text": json.dumps(pass2_response)}]}, + # Pass 1.5: self-critique + {"content": [{"text": json.dumps(critique_response)}]}, + # Pass 2: verdict (N votes) + ] + [ + {"content": [{"text": json.dumps(pass2_response)}]} + for _ in range(_VERDICT_VOTES) ] with patch( "primer_classifier.classifier.fetch_files_by_path", @@ -1138,7 +1174,8 @@ def test_two_pass_with_file_request(self): ): result = classify_project(p, fetch_code=True, use_llm=True) assert result.verdict == "improvement" - assert mock_api.call_count == 3 + # Pass 1 (2 attempts) + Pass 1.5 + Pass 2 (N votes) + assert mock_api.call_count == 3 + _VERDICT_VOTES # --------------------------------------------------------------------------- @@ -1228,6 +1265,11 @@ def test_full_pipeline(self): "reason": "Variance check too broad", "pr_attribution": "Removed is_protocol() guard", } + critique_response = { + "corrected": False, + "corrections": "", + "reason": "Variance check too broad", + } pass2_response = {"verdict": "regression"} pass3_response = { "summary": "Restore protocol guard", @@ -1248,9 +1290,15 @@ def test_full_pipeline(self): with patch( "primer_classifier.llm_client._call_anthropic_api", ) as mock_api: + from .llm_client import _VERDICT_VOTES + mock_api.side_effect = [ {"content": [{"text": json.dumps(pass1_response)}]}, # Pass 1 - {"content": [{"text": json.dumps(pass2_response)}]}, # Pass 2 + {"content": [{"text": json.dumps(critique_response)}]}, # Pass 1.5 + ] + [ + {"content": [{"text": json.dumps(pass2_response)}]} # Pass 2 + for _ in range(_VERDICT_VOTES) + ] + [ {"content": [{"text": json.dumps(pass3_response)}]}, # Pass 3 ] result = classify_all( @@ -1263,7 +1311,8 @@ def test_full_pipeline(self): assert result.suggestion is not None assert len(result.suggestion.suggestions) == 1 assert result.suggestion.suggestions[0].description == "Add is_protocol() check" - assert mock_api.call_count == 3 + # Pass 1 + Pass 1.5 + Pass 2 (N votes) + Pass 3 + assert mock_api.call_count == 2 + _VERDICT_VOTES + 1 class TestSuggestionInMarkdownOutput: