From 98113e248c328eb6772c3da430e5bee6a38e4451 Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Thu, 12 Feb 2026 16:00:24 -0800 Subject: [PATCH 1/3] adding MMLU to evals, updating corresponding tests --- tests/acceptance/test_evals.py | 69 +++++++- transformer_lens/evals.py | 307 ++++++++++++++++++++++++++++++++- 2 files changed, 374 insertions(+), 2 deletions(-) diff --git a/tests/acceptance/test_evals.py b/tests/acceptance/test_evals.py index 5936466f6..0644651b3 100644 --- a/tests/acceptance/test_evals.py +++ b/tests/acceptance/test_evals.py @@ -1,6 +1,11 @@ import pytest -from transformer_lens.evals import IOIDataset, ioi_eval +from transformer_lens.evals import ( + IOIDataset, + ioi_eval, + make_mmlu_data_loader, + mmlu_eval, +) from transformer_lens.HookedTransformer import HookedTransformer @@ -70,3 +75,65 @@ def test_inverted_template(model): results = ioi_eval(model, dataset=ds) assert results["Logit Difference"] < -2.0 assert results["Accuracy"] <= 0.01 + + +def test_mmlu_data_loader_single_subject(): + """ + Test loading MMLU data for a single subject. + """ + data = make_mmlu_data_loader(subjects="abstract_algebra", num_samples=5) + assert len(data) == 5 + assert all(isinstance(d, dict) for d in data) + assert all("question" in d for d in data) + assert all("choices" in d for d in data) + assert all("answer" in d for d in data) + assert all("subject" in d for d in data) + assert all(len(d["choices"]) == 4 for d in data) + assert all(d["subject"] == "abstract_algebra" for d in data) + + +def test_mmlu_data_loader_multiple_subjects(): + """ + Test loading MMLU data for multiple subjects. + """ + subjects = ["abstract_algebra", "anatomy"] + data = make_mmlu_data_loader(subjects=subjects, num_samples=3) + assert len(data) == 6 # 3 samples per subject + subjects_in_data = {d["subject"] for d in data} + assert subjects_in_data == set(subjects) + + +def test_mmlu_data_loader_invalid_subject(): + """ + Test that invalid subject names raise an error. + """ + with pytest.raises(ValueError, match="Invalid subject"): + make_mmlu_data_loader(subjects="invalid_subject_name") + + +def test_mmlu_eval_single_subject(model): + """ + Test MMLU evaluation on a single subject with a small number of samples. + Uses a small model and few samples for fast CI execution. + """ + results = mmlu_eval(model, subjects="abstract_algebra", num_samples=5, device="cpu") + assert "accuracy" in results + assert "num_correct" in results + assert "num_total" in results + assert "subject_scores" in results + assert 0 <= results["accuracy"] <= 1 + assert results["num_total"] == 5 + assert results["num_correct"] <= results["num_total"] + assert "abstract_algebra" in results["subject_scores"] + + +def test_mmlu_eval_multiple_subjects(model): + """ + Test MMLU evaluation on multiple subjects. + """ + subjects = ["abstract_algebra", "anatomy"] + results = mmlu_eval(model, subjects=subjects, num_samples=3, device="cpu") + assert results["num_total"] == 6 # 3 samples per subject + assert len(results["subject_scores"]) == 2 + assert all(subject in results["subject_scores"] for subject in subjects) + assert all(0 <= acc <= 1 for acc in results["subject_scores"].values()) diff --git a/transformer_lens/evals.py b/transformer_lens/evals.py index b77c727c5..9f5eb3c28 100644 --- a/transformer_lens/evals.py +++ b/transformer_lens/evals.py @@ -6,7 +6,7 @@ """ import random -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import einops import torch @@ -85,6 +85,160 @@ def make_code_data_loader(tokenizer, batch_size=8): return data_loader +# All 57 subjects available in the MMLU benchmark +MMLU_SUBJECTS = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_medicine", + "college_physics", + "computer_security", + "conceptual_physics", + "econometrics", + "electrical_engineering", + "elementary_mathematics", + "formal_logic", + "global_facts", + "high_school_biology", + "high_school_chemistry", + "high_school_computer_science", + "high_school_european_history", + "high_school_geography", + "high_school_government_and_politics", + "high_school_macroeconomics", + "high_school_mathematics", + "high_school_microeconomics", + "high_school_physics", + "high_school_psychology", + "high_school_statistics", + "high_school_us_history", + "high_school_world_history", + "human_aging", + "human_sexuality", + "international_law", + "jurisprudence", + "logical_fallacies", + "machine_learning", + "management", + "marketing", + "medical_genetics", + "miscellaneous", + "moral_disputes", + "moral_scenarios", + "nutrition", + "philosophy", + "prehistory", + "professional_accounting", + "professional_law", + "professional_medicine", + "professional_psychology", + "public_relations", + "security_studies", + "sociology", + "us_foreign_policy", + "virology", + "world_religions", +] + +MMLU_ANSWER_LETTERS = ["A", "B", "C", "D"] + + +def make_mmlu_data_loader( + subjects: Optional[Union[str, List[str]]] = None, + split: str = "test", + num_samples: Optional[int] = None, +): + """ + Load MMLU (Massive Multitask Language Understanding) dataset. + + MMLU tests model performance on 57 subjects across STEM, humanities, social sciences, + and more. Each question is multiple choice with 4 options (A, B, C, D). + + Paper: https://arxiv.org/abs/2009.03300 + Dataset: https://huggingface.co/datasets/cais/mmlu + + Args: + subjects: Subject(s) to evaluate on. Can be: + - None: Use all 57 subjects (default) + - str: Single subject name (e.g., "abstract_algebra") + - List[str]: Multiple subjects + split: Which split to use - "test", "validation", or "dev". Default is "test". + num_samples: Optional limit on number of samples per subject. If None, uses all samples. + + Returns: + List of dictionaries with MMLU examples, each containing: + - "question": str + - "choices": List[str] (4 choices) + - "answer": int (0-3, correct choice index) + - "subject": str + + Examples: + + .. code-block:: python + + >>> from transformer_lens.evals import make_mmlu_data_loader + + >>> # Load specific subject + >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics") + + >>> # Load multiple subjects + >>> mmlu_data = make_mmlu_data_loader( + ... subjects=["abstract_algebra", "astronomy", "college_chemistry"] + ... ) + """ + # Handle subjects parameter + if subjects is None: + subjects_to_load = MMLU_SUBJECTS + elif isinstance(subjects, str): + subjects_to_load = [subjects] + else: + subjects_to_load = list(subjects) + + # Validate subjects + invalid_subjects = set(subjects_to_load) - set(MMLU_SUBJECTS) + if invalid_subjects: + raise ValueError( + f"Invalid subject(s): {invalid_subjects}. " + f"Valid subjects: {', '.join(sorted(MMLU_SUBJECTS))}" + ) + + # Load data for each subject + mmlu_data = [] + for subject in subjects_to_load: + try: + # Load dataset for this subject + dataset = load_dataset("cais/mmlu", subject, split=split) + + # Limit samples if requested + samples_to_take = ( + len(dataset) if num_samples is None else min(num_samples, len(dataset)) + ) + + # Convert to our format + for i in range(samples_to_take): + example = dataset[i] + mmlu_data.append( + { + "question": example["question"], + "choices": example["choices"], + "answer": example["answer"], + "subject": subject, + } + ) + except Exception as e: + print(f"Warning: Could not load subject '{subject}': {e}") + continue + + print(f"Loaded {len(mmlu_data)} MMLU examples from {len(subjects_to_load)} subject(s)") + return mmlu_data + + DATASET_NAMES = ["wiki", "owt", "pile", "code"] DATASET_LOADERS = [ make_wiki_data_loader, @@ -334,3 +488,154 @@ def collate(samples): "Logit Difference": total_logit_diff / len(dataset), "Accuracy": total_correct / len(dataset), } + + +@torch.inference_mode() +def mmlu_eval( + model, + tokenizer=None, + subjects: Optional[Union[str, List[str]]] = None, + split: str = "test", + num_samples: Optional[int] = None, + device: str = "cuda", +): + """Evaluate a model on the MMLU benchmark. + + MMLU (Massive Multitask Language Understanding) is a benchmark for evaluating language models + on 57 subjects across STEM, humanities, social sciences, and more. Each question is + multiple-choice with 4 options. + + For each question, all four answer choices (A-D) are shown in the prompt and the model's + log probability for each answer letter token is compared. This is a zero-shot evaluation; + standard MMLU benchmarks typically use 5-shot prompting for higher accuracy. + + Paper: https://arxiv.org/abs/2009.03300 + + Args: + model: HookedTransformer model to evaluate. + tokenizer: Tokenizer to use. If None, uses model.tokenizer. + subjects: Subject(s) to evaluate on. Can be None (all 57 subjects), a single subject + string, or a list of subjects. See :const:`MMLU_SUBJECTS` for valid names. + split: Which split to use - "test", "validation", or "dev". Default is "test". + num_samples: Optional limit on number of samples per subject. If None, uses all samples. + device: Device to run evaluation on. Default is "cuda". + + Returns: + Dictionary containing: + - "accuracy": Overall accuracy (0-1) + - "num_correct": Number of correct predictions + - "num_total": Total number of questions + - "subject_scores": Dict mapping subject names to their accuracy + + Examples: + + .. code-block:: python + + >>> from transformer_lens import HookedTransformer + >>> from transformer_lens.evals import mmlu_eval + + >>> model = HookedTransformer.from_pretrained("gpt2-small") + Loaded pretrained model gpt2-small into HookedTransformer + + >>> # Evaluate on a specific subject + >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10) + >>> print(f"Accuracy: {results['accuracy']:.2%}") + Accuracy: 30.00% + + >>> # Evaluate on multiple subjects + >>> results = mmlu_eval( + ... model, + ... subjects=["astronomy", "college_mathematics"], + ... num_samples=20 + ... ) + >>> for subject, acc in results["subject_scores"].items(): + ... print(f"{subject}: {acc:.2%}") + astronomy: 35.00% + college_mathematics: 25.00% + """ + if tokenizer is None: + tokenizer = model.tokenizer + + # Load MMLU data + mmlu_data = make_mmlu_data_loader(subjects=subjects, split=split, num_samples=num_samples) + + if len(mmlu_data) == 0: + raise ValueError("No MMLU data loaded. Check your subjects parameter.") + + # Precompute token IDs for answer letters A, B, C, D + # Done once here instead of per-question for efficiency + answer_letter_token_ids = [] + for letter in MMLU_ANSWER_LETTERS: + # Try with space prefix first (how it appears after "Answer:") + token_ids = tokenizer.encode(" " + letter, add_special_tokens=False) + if len(token_ids) == 1: + answer_letter_token_ids.append(token_ids[0]) + else: + # Fallback to without space + token_ids = tokenizer.encode(letter, add_special_tokens=False) + answer_letter_token_ids.append(token_ids[0]) + + # Track results + num_correct = 0 + num_total = 0 + subject_correct: Dict[str, int] = {} + subject_total: Dict[str, int] = {} + + # Process examples + for example in tqdm.tqdm(mmlu_data, desc="Evaluating MMLU"): + question = example["question"] + choices = example["choices"] + correct_answer = example["answer"] + subject = example["subject"] + + # Initialize subject tracking + if subject not in subject_correct: + subject_correct[subject] = 0 + subject_total[subject] = 0 + + # Format prompt with all choices shown (standard MMLU format) + prompt = f"Question: {question}\n" + prompt += "Choices:\n" + for idx, choice_text in enumerate(choices): + letter = chr(65 + idx) # A, B, C, D + prompt += f"{letter}. {choice_text}\n" + prompt += "Answer:" + + # Tokenize the prompt + tokens = tokenizer.encode(prompt, return_tensors="pt").to(device) + + # Get logits + logits = model(tokens, return_type="logits") + + # Get log probabilities at the last position (predicting the answer letter) + last_log_probs = torch.nn.functional.log_softmax(logits[0, -1, :], dim=-1) + + # Score each answer choice by its letter token probability + choice_log_probs = [] + for idx in range(len(choices)): + token_id = answer_letter_token_ids[idx] + choice_log_probs.append(last_log_probs[token_id].item()) + + # Select the choice with highest log probability + predicted_answer = choice_log_probs.index(max(choice_log_probs)) + + # Check if correct + is_correct = predicted_answer == correct_answer + num_correct += int(is_correct) + num_total += 1 + subject_correct[subject] += int(is_correct) + subject_total[subject] += 1 + + # Compute accuracies + overall_accuracy = num_correct / num_total if num_total > 0 else 0.0 + subject_scores = { + subject: subject_correct[subject] / subject_total[subject] + for subject in subject_correct.keys() + } + + return { + "accuracy": overall_accuracy, + "num_correct": num_correct, + "num_total": num_total, + "subject_scores": subject_scores, + } From 14a4fb11595ee7f4368c3344c622e399dfaa6a75 Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Mon, 23 Feb 2026 17:37:18 -0800 Subject: [PATCH 2/3] Fix docstring tests for MMLU functions Skip MMLU docstring examples in doctest runs since they require network access (HuggingFace dataset download) and may require GPU. Co-Authored-By: Claude Opus 4.6 --- transformer_lens/evals.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/transformer_lens/evals.py b/transformer_lens/evals.py index 9f5eb3c28..7d5d5c7d6 100644 --- a/transformer_lens/evals.py +++ b/transformer_lens/evals.py @@ -185,10 +185,10 @@ def make_mmlu_data_loader( >>> from transformer_lens.evals import make_mmlu_data_loader >>> # Load specific subject - >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics") + >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics") # doctest: +SKIP >>> # Load multiple subjects - >>> mmlu_data = make_mmlu_data_loader( + >>> mmlu_data = make_mmlu_data_loader( # doctest: +SKIP ... subjects=["abstract_algebra", "astronomy", "college_chemistry"] ... ) """ @@ -534,24 +534,9 @@ def mmlu_eval( >>> from transformer_lens import HookedTransformer >>> from transformer_lens.evals import mmlu_eval - >>> model = HookedTransformer.from_pretrained("gpt2-small") - Loaded pretrained model gpt2-small into HookedTransformer - - >>> # Evaluate on a specific subject - >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10) - >>> print(f"Accuracy: {results['accuracy']:.2%}") - Accuracy: 30.00% - - >>> # Evaluate on multiple subjects - >>> results = mmlu_eval( - ... model, - ... subjects=["astronomy", "college_mathematics"], - ... num_samples=20 - ... ) - >>> for subject, acc in results["subject_scores"].items(): - ... print(f"{subject}: {acc:.2%}") - astronomy: 35.00% - college_mathematics: 25.00% + >>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP + >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10, device="cpu") # doctest: +SKIP + >>> print(f"Accuracy: {results['accuracy']:.2%}") # doctest: +SKIP """ if tokenizer is None: tokenizer = model.tokenizer From eae677f374fc18f693dca331b57cce74b8a6266c Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Fri, 6 Mar 2026 15:02:41 -0800 Subject: [PATCH 3/3] Remove device parameter from mmlu_eval, use model.cfg.device instead Address PR review feedback: use model.cfg.device internally instead of accepting a device parameter, consistent with ioi_eval. This prevents device mismatch errors when users forget to pass the correct device. --- tests/acceptance/test_evals.py | 4 ++-- transformer_lens/evals.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/acceptance/test_evals.py b/tests/acceptance/test_evals.py index 0644651b3..d55eefcef 100644 --- a/tests/acceptance/test_evals.py +++ b/tests/acceptance/test_evals.py @@ -116,7 +116,7 @@ def test_mmlu_eval_single_subject(model): Test MMLU evaluation on a single subject with a small number of samples. Uses a small model and few samples for fast CI execution. """ - results = mmlu_eval(model, subjects="abstract_algebra", num_samples=5, device="cpu") + results = mmlu_eval(model, subjects="abstract_algebra", num_samples=5) assert "accuracy" in results assert "num_correct" in results assert "num_total" in results @@ -132,7 +132,7 @@ def test_mmlu_eval_multiple_subjects(model): Test MMLU evaluation on multiple subjects. """ subjects = ["abstract_algebra", "anatomy"] - results = mmlu_eval(model, subjects=subjects, num_samples=3, device="cpu") + results = mmlu_eval(model, subjects=subjects, num_samples=3) assert results["num_total"] == 6 # 3 samples per subject assert len(results["subject_scores"]) == 2 assert all(subject in results["subject_scores"] for subject in subjects) diff --git a/transformer_lens/evals.py b/transformer_lens/evals.py index 7d5d5c7d6..8137d0740 100644 --- a/transformer_lens/evals.py +++ b/transformer_lens/evals.py @@ -497,7 +497,6 @@ def mmlu_eval( subjects: Optional[Union[str, List[str]]] = None, split: str = "test", num_samples: Optional[int] = None, - device: str = "cuda", ): """Evaluate a model on the MMLU benchmark. @@ -518,7 +517,6 @@ def mmlu_eval( string, or a list of subjects. See :const:`MMLU_SUBJECTS` for valid names. split: Which split to use - "test", "validation", or "dev". Default is "test". num_samples: Optional limit on number of samples per subject. If None, uses all samples. - device: Device to run evaluation on. Default is "cuda". Returns: Dictionary containing: @@ -535,7 +533,7 @@ def mmlu_eval( >>> from transformer_lens.evals import mmlu_eval >>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP - >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10, device="cpu") # doctest: +SKIP + >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10) # doctest: +SKIP >>> print(f"Accuracy: {results['accuracy']:.2%}") # doctest: +SKIP """ if tokenizer is None: @@ -587,7 +585,7 @@ def mmlu_eval( prompt += "Answer:" # Tokenize the prompt - tokens = tokenizer.encode(prompt, return_tensors="pt").to(device) + tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.cfg.device) # Get logits logits = model(tokens, return_type="logits")