69 changes: 68 additions & 1 deletion tests/acceptance/test_evals.py
@@ -1,6 +1,11 @@
import pytest

from transformer_lens.evals import IOIDataset, ioi_eval
from transformer_lens.evals import (
    IOIDataset,
    ioi_eval,
    make_mmlu_data_loader,
    mmlu_eval,
)
from transformer_lens.HookedTransformer import HookedTransformer


@@ -70,3 +75,65 @@ def test_inverted_template(model):
    results = ioi_eval(model, dataset=ds)
    assert results["Logit Difference"] < -2.0
    assert results["Accuracy"] <= 0.01


def test_mmlu_data_loader_single_subject():
"""
Test loading MMLU data for a single subject.
"""
data = make_mmlu_data_loader(subjects="abstract_algebra", num_samples=5)
assert len(data) == 5
assert all(isinstance(d, dict) for d in data)
assert all("question" in d for d in data)
assert all("choices" in d for d in data)
assert all("answer" in d for d in data)
assert all("subject" in d for d in data)
assert all(len(d["choices"]) == 4 for d in data)
assert all(d["subject"] == "abstract_algebra" for d in data)


def test_mmlu_data_loader_multiple_subjects():
"""
Test loading MMLU data for multiple subjects.
"""
subjects = ["abstract_algebra", "anatomy"]
data = make_mmlu_data_loader(subjects=subjects, num_samples=3)
assert len(data) == 6 # 3 samples per subject
subjects_in_data = {d["subject"] for d in data}
assert subjects_in_data == set(subjects)


def test_mmlu_data_loader_invalid_subject():
"""
Test that invalid subject names raise an error.
"""
with pytest.raises(ValueError, match="Invalid subject"):
make_mmlu_data_loader(subjects="invalid_subject_name")


def test_mmlu_eval_single_subject(model):
"""
Test MMLU evaluation on a single subject with a small number of samples.
Uses a small model and few samples for fast CI execution.
"""
results = mmlu_eval(model, subjects="abstract_algebra", num_samples=5)
assert "accuracy" in results
assert "num_correct" in results
assert "num_total" in results
assert "subject_scores" in results
assert 0 <= results["accuracy"] <= 1
assert results["num_total"] == 5
assert results["num_correct"] <= results["num_total"]
assert "abstract_algebra" in results["subject_scores"]


def test_mmlu_eval_multiple_subjects(model):
"""
Test MMLU evaluation on multiple subjects.
"""
subjects = ["abstract_algebra", "anatomy"]
results = mmlu_eval(model, subjects=subjects, num_samples=3)
assert results["num_total"] == 6 # 3 samples per subject
assert len(results["subject_scores"]) == 2
assert all(subject in results["subject_scores"] for subject in subjects)
assert all(0 <= acc <= 1 for acc in results["subject_scores"].values())
290 changes: 289 additions & 1 deletion transformer_lens/evals.py
@@ -6,7 +6,7 @@
"""

import random
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import einops
import torch
@@ -85,6 +85,160 @@ def make_code_data_loader(tokenizer, batch_size=8):
    return data_loader


# All 57 subjects available in the MMLU benchmark
MMLU_SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]

MMLU_ANSWER_LETTERS = ["A", "B", "C", "D"]


def make_mmlu_data_loader(
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
"""
Load MMLU (Massive Multitask Language Understanding) dataset.

MMLU tests model performance on 57 subjects across STEM, humanities, social sciences,
and more. Each question is multiple choice with 4 options (A, B, C, D).

Paper: https://arxiv.org/abs/2009.03300
Dataset: https://huggingface.co/datasets/cais/mmlu

Args:
subjects: Subject(s) to evaluate on. Can be:
- None: Use all 57 subjects (default)
- str: Single subject name (e.g., "abstract_algebra")
- List[str]: Multiple subjects
split: Which split to use - "test", "validation", or "dev". Default is "test".
num_samples: Optional limit on number of samples per subject. If None, uses all samples.

Returns:
List of dictionaries with MMLU examples, each containing:
- "question": str
- "choices": List[str] (4 choices)
- "answer": int (0-3, correct choice index)
- "subject": str

Examples:

.. code-block:: python

>>> from transformer_lens.evals import make_mmlu_data_loader

>>> # Load specific subject
>>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics") # doctest: +SKIP

>>> # Load multiple subjects
>>> mmlu_data = make_mmlu_data_loader( # doctest: +SKIP
... subjects=["abstract_algebra", "astronomy", "college_chemistry"]
... )
"""
    # Handle the subjects parameter
    if subjects is None:
        subjects_to_load = MMLU_SUBJECTS
    elif isinstance(subjects, str):
        subjects_to_load = [subjects]
    else:
        subjects_to_load = list(subjects)

    # Validate subjects
    invalid_subjects = set(subjects_to_load) - set(MMLU_SUBJECTS)
    if invalid_subjects:
        raise ValueError(
            f"Invalid subject(s): {invalid_subjects}. "
            f"Valid subjects: {', '.join(sorted(MMLU_SUBJECTS))}"
        )

    # Load data for each subject
    mmlu_data = []
    for subject in subjects_to_load:
        try:
            # Load the dataset for this subject
            dataset = load_dataset("cais/mmlu", subject, split=split)

            # Limit samples if requested
            samples_to_take = (
                len(dataset) if num_samples is None else min(num_samples, len(dataset))
            )

            # Convert to our format
            for i in range(samples_to_take):
                example = dataset[i]
                mmlu_data.append(
                    {
                        "question": example["question"],
                        "choices": example["choices"],
                        "answer": example["answer"],
                        "subject": subject,
                    }
                )
        except Exception as e:
            print(f"Warning: Could not load subject '{subject}': {e}")
            continue

    print(f"Loaded {len(mmlu_data)} MMLU examples from {len(subjects_to_load)} subject(s)")
    return mmlu_data


DATASET_NAMES = ["wiki", "owt", "pile", "code"]
DATASET_LOADERS = [
    make_wiki_data_loader,
@@ -334,3 +488,137 @@ def collate(samples):
"Logit Difference": total_logit_diff / len(dataset),
"Accuracy": total_correct / len(dataset),
}


@torch.inference_mode()
def mmlu_eval(
    model,
    tokenizer=None,
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
"""Evaluate a model on the MMLU benchmark.

MMLU (Massive Multitask Language Understanding) is a benchmark for evaluating language models
on 57 subjects across STEM, humanities, social sciences, and more. Each question is
multiple-choice with 4 options.

For each question, all four answer choices (A-D) are shown in the prompt and the model's
log probability for each answer letter token is compared. This is a zero-shot evaluation;
standard MMLU benchmarks typically use 5-shot prompting for higher accuracy.

Paper: https://arxiv.org/abs/2009.03300

Args:
model: HookedTransformer model to evaluate.
tokenizer: Tokenizer to use. If None, uses model.tokenizer.
subjects: Subject(s) to evaluate on. Can be None (all 57 subjects), a single subject
string, or a list of subjects. See :const:`MMLU_SUBJECTS` for valid names.
split: Which split to use - "test", "validation", or "dev". Default is "test".
num_samples: Optional limit on number of samples per subject. If None, uses all samples.

Returns:
Dictionary containing:
- "accuracy": Overall accuracy (0-1)
- "num_correct": Number of correct predictions
- "num_total": Total number of questions
- "subject_scores": Dict mapping subject names to their accuracy

Examples:

.. code-block:: python

>>> from transformer_lens import HookedTransformer
>>> from transformer_lens.evals import mmlu_eval

>>> model = HookedTransformer.from_pretrained("gpt2-small") # doctest: +SKIP
>>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10) # doctest: +SKIP
>>> print(f"Accuracy: {results['accuracy']:.2%}") # doctest: +SKIP
"""
    if tokenizer is None:
        tokenizer = model.tokenizer

    # Load MMLU data
    mmlu_data = make_mmlu_data_loader(subjects=subjects, split=split, num_samples=num_samples)

    if len(mmlu_data) == 0:
        raise ValueError("No MMLU data loaded. Check your subjects parameter.")

    # Precompute token IDs for answer letters A, B, C, D
    # Done once here instead of per-question for efficiency
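    # Note: with GPT-2-style BPE tokenizers, " A", " B", " C", and " D" are usually single
    # tokens, so the no-space fallback below should rarely be needed (assumption; other
    # tokenizers may behave differently).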
    answer_letter_token_ids = []
    for letter in MMLU_ANSWER_LETTERS:
        # Try with space prefix first (how it appears after "Answer:")
        token_ids = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(token_ids) == 1:
            answer_letter_token_ids.append(token_ids[0])
        else:
            # Fall back to the letter without a leading space
            token_ids = tokenizer.encode(letter, add_special_tokens=False)
            answer_letter_token_ids.append(token_ids[0])

    # Track results
    num_correct = 0
    num_total = 0
    subject_correct: Dict[str, int] = {}
    subject_total: Dict[str, int] = {}

    # Process examples
    for example in tqdm.tqdm(mmlu_data, desc="Evaluating MMLU"):
        question = example["question"]
        choices = example["choices"]
        correct_answer = example["answer"]
        subject = example["subject"]

        # Initialize subject tracking
        if subject not in subject_correct:
            subject_correct[subject] = 0
            subject_total[subject] = 0

        # Format prompt with all choices shown (standard MMLU format)
        prompt = f"Question: {question}\n"
        prompt += "Choices:\n"
        for idx, choice_text in enumerate(choices):
            letter = chr(65 + idx)  # A, B, C, D
            prompt += f"{letter}. {choice_text}\n"
        prompt += "Answer:"

        # Tokenize the prompt
        tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.cfg.device)

        # Get logits
        logits = model(tokens, return_type="logits")

        # Get log probabilities at the last position (predicting the answer letter)
        last_log_probs = torch.nn.functional.log_softmax(logits[0, -1, :], dim=-1)

        # Score each answer choice by its letter token probability
        choice_log_probs = []
        for idx in range(len(choices)):
            token_id = answer_letter_token_ids[idx]
            choice_log_probs.append(last_log_probs[token_id].item())

        # Select the choice with the highest log probability
        predicted_answer = choice_log_probs.index(max(choice_log_probs))

        # Check if correct
        is_correct = predicted_answer == correct_answer
        num_correct += int(is_correct)
        num_total += 1
        subject_correct[subject] += int(is_correct)
        subject_total[subject] += 1

    # Compute accuracies
    overall_accuracy = num_correct / num_total if num_total > 0 else 0.0
    subject_scores = {
        subject: subject_correct[subject] / subject_total[subject]
        for subject in subject_correct.keys()
    }

    return {
        "accuracy": overall_accuracy,
        "num_correct": num_correct,
        "num_total": num_total,
        "subject_scores": subject_scores,
    }