172 changes: 171 additions & 1 deletion src/pruna/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Tuple

@@ -28,6 +29,7 @@
from pruna.data.datasets.prompt import (
setup_drawbench_dataset,
setup_genai_bench_dataset,
setup_geneval_dataset,
setup_parti_prompts_dataset,
)
from pruna.data.datasets.question_answering import setup_polyglot_dataset
@@ -97,8 +99,176 @@
{"img_size": 224},
),
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
"GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}),
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
}


@dataclass
class BenchmarkInfo:
"""
Metadata for a benchmark dataset.

Parameters
----------
name : str
Internal identifier for the benchmark.
display_name : str
Human-readable name for display purposes.
description : str
Description of what the benchmark evaluates.
metrics : list[str]
List of metric names used for evaluation.
task_type : str
Type of task the benchmark evaluates (e.g., 'text_to_image').
subsets : list[str]
Optional list of benchmark subset names.
"""

name: str
display_name: str
description: str
metrics: list[str]
task_type: str
subsets: list[str] = field(default_factory=list)


benchmark_info: dict[str, BenchmarkInfo] = {
"PartiPrompts": BenchmarkInfo(
name="parti_prompts",
display_name="Parti Prompts",
description=(
"Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
"ranging from basic to complex, enabling comprehensive assessment of model capabilities "
"across different domains and difficulty levels."
),
metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
subsets=[
"Abstract",
"Animals",
"Artifacts",
"Arts",
"Food & Beverage",
"Illustrations",
"Indoor Scenes",
"Outdoor Scenes",
"People",
"Produce & Plants",
"Vehicles",
"World Knowledge",
"Basic",
"Complex",
"Fine-grained Detail",
"Imagination",
"Linguistic Structures",
"Perspective",
"Properties & Positioning",
"Quantity",
"Simple Detail",
"Style & Format",
"Writing & Symbols",
],
),
"DrawBench": BenchmarkInfo(
name="drawbench",
display_name="DrawBench",
description="A comprehensive benchmark for evaluating text-to-image generation models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"GenAIBench": BenchmarkInfo(
name="genai_bench",
display_name="GenAI Bench",
description="A benchmark for evaluating generative AI models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"VBench": BenchmarkInfo(
name="vbench",
display_name="VBench",
description="A benchmark for evaluating video generation models.",
metrics=["clip_score"],
task_type="text_to_video",
),
"GenEval": BenchmarkInfo(
name="geneval",
display_name="GenEval",
description=(
"Fine-grained compositional evaluation across object co-occurrence, positioning, "
"counting, and color binding to identify specific failure modes in text-to-image alignment."
),
metrics=["accuracy"],
task_type="text_to_image",
subsets=["single_object", "two_object", "counting", "colors", "position", "color_attr"],
),
"COCO": BenchmarkInfo(
name="coco",
display_name="COCO",
description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
metrics=["fid", "clip_score", "clipiqa"],
task_type="text_to_image",
),
"ImageNet": BenchmarkInfo(
name="imagenet",
display_name="ImageNet",
description="Large-scale image classification benchmark with 1000 classes.",
metrics=["accuracy"],
task_type="image_classification",
),
"WikiText": BenchmarkInfo(
name="wikitext",
display_name="WikiText",
description="Language modeling benchmark based on Wikipedia articles.",
metrics=["perplexity"],
task_type="text_generation",
),
}


def list_benchmarks(task_type: str | None = None) -> list[str]:
"""
List available benchmark names.

Parameters
----------
task_type : str | None
Filter by task type (e.g., 'text_to_image', 'text_to_video').
If None, returns all benchmarks.

Returns
-------
list[str]
List of benchmark names.
"""
if task_type is None:
return list(benchmark_info.keys())
return [name for name, info in benchmark_info.items() if info.task_type == task_type]


def get_benchmark_info(name: str) -> BenchmarkInfo:
"""
Get benchmark metadata by name.

Parameters
----------
name : str
The benchmark name.

Returns
-------
BenchmarkInfo
The benchmark metadata.

Raises
------
KeyError
If benchmark name is not found.
"""
if name not in benchmark_info:
available = ", ".join(benchmark_info.keys())
raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
return benchmark_info[name]
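
A minimal usage sketch for the new benchmark-metadata helpers (assuming the layout above, i.e. that list_benchmarks and get_benchmark_info are importable from pruna.data; the printed values follow directly from the benchmark_info table):

from pruna.data import get_benchmark_info, list_benchmarks

# All registered benchmarks, optionally filtered by task type.
print(list_benchmarks(task_type="text_to_image"))
# -> ['PartiPrompts', 'DrawBench', 'GenAIBench', 'GenEval', 'COCO']

# Metadata lookup by registry key; unknown names raise KeyError listing the available keys.
info = get_benchmark_info("GenEval")
print(info.display_name)  # "GenEval"
print(info.metrics)       # ["accuracy"]
print(info.subsets)       # ["single_object", "two_object", "counting", "colors", "position", "color_attr"]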
110 changes: 108 additions & 2 deletions src/pruna/data/datasets/prompt.py
@@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
return ds.select([0]), ds.select([0]), ds


def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
def setup_parti_prompts_dataset(
seed: int,
category: str | None = None,
num_samples: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the Parti Prompts dataset.

@@ -51,18 +55,120 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
----------
seed : int
The seed to use.
category : str | None
Filter by Category or Challenge. Available categories: Abstract, Animals, Artifacts,
Arts, Food & Beverage, Illustrations, Indoor Scenes, Outdoor Scenes, People,
Produce & Plants, Vehicles, World Knowledge. Available challenges: Basic, Complex,
Fine-grained Detail, Imagination, Linguistic Structures, Perspective,
Properties & Positioning, Quantity, Simple Detail, Style & Format, Writing & Symbols.
num_samples : int | None
Maximum number of samples to return. If None, returns all samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The Parti Prompts dataset.
The Parti Prompts dataset (dummy train, dummy val, test).
"""
ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index]

if category is not None:
ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)

ds = ds.shuffle(seed=seed)

if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

ds = ds.rename_column("Prompt", "text")
pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds


GENEVAL_CATEGORIES = ["single_object", "two_object", "counting", "colors", "position", "color_attr"]


def _generate_geneval_question(entry: dict) -> list[str]:
"""Generate evaluation questions from GenEval metadata."""
tag = entry.get("tag", "")
include = entry.get("include", [])
questions = []

for obj in include:
cls = obj.get("class", "")
if "color" in obj:
questions.append(f"Does the image contain a {obj['color']} {cls}?")
elif "count" in obj:
questions.append(f"Does the image contain exactly {obj['count']} {cls}(s)?")
else:
questions.append(f"Does the image contain a {cls}?")

if tag == "position" and len(include) >= 2:
a_cls = include[0].get("class", "")
b_cls = include[1].get("class", "")
pos = include[1].get("position")
if pos and pos[0]:
questions.append(f"Is the {b_cls} {pos[0]} the {a_cls}?")

return questions


def setup_geneval_dataset(
seed: int,
category: str | None = None,
num_samples: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the GenEval benchmark dataset.

License: MIT

Parameters
----------
seed : int
The seed to use.
category : str | None
Filter by category. Available: single_object, two_object, counting, colors, position, color_attr.
num_samples : int | None
Maximum number of samples to return. If None, returns all samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The GenEval dataset (dummy train, dummy val, test).
"""
import json

import requests

url = "https://raw.githubusercontent.com/djghosh13/geneval/d927da8e42fde2b1b5cd743da4df5ff83c1654ff/prompts/evaluation_metadata.jsonl"
response = requests.get(url)
data = [json.loads(line) for line in response.text.splitlines()]

if category is not None:
if category not in GENEVAL_CATEGORIES:
raise ValueError(f"Invalid category: {category}. Must be one of {GENEVAL_CATEGORIES}")
data = [entry for entry in data if entry.get("tag") == category]

records = []
for entry in data:
questions = _generate_geneval_question(entry)
records.append({
"text": entry["prompt"],
"tag": entry.get("tag", ""),
"questions": questions,
"include": entry.get("include", []),
})

ds = Dataset.from_list(records)
ds = ds.shuffle(seed=seed)

if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

pruna_logger.info("GenEval is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds


def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the GenAI Bench dataset.
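
The new category and num_samples parameters are reached through the existing datamodule entry point. A short sketch mirroring the tests below (it assumes PrunaDataModule.from_string forwards extra keyword arguments such as category to the setup functions, as exercised in tests/data/test_datamodule.py):

from pruna.data.pruna_datamodule import PrunaDataModule

# GenEval restricted to the "counting" subset; prompt_with_auxiliaries_collate
# pairs each prompt with its auxiliary metadata (tag, questions, include).
dm = PrunaDataModule.from_string("GenEval", category="counting", dataloader_args={"batch_size": 4})
dm.limit_datasets(10)
prompts, auxiliaries = next(iter(dm.test_dataloader()))

# PartiPrompts filtered by Category (or Challenge) works the same way.
dm = PrunaDataModule.from_string("PartiPrompts", category="Animals", dataloader_args={"batch_size": 4})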
39 changes: 36 additions & 3 deletions tests/data/test_datamodule.py
@@ -1,10 +1,10 @@
from typing import Any, Callable

import pytest
from transformers import AutoTokenizer
from datasets import Dataset
from torch.utils.data import TensorDataset
import torch
from transformers import AutoTokenizer

from pruna.data import BenchmarkInfo, benchmark_info
from pruna.data.datasets.image import setup_imagenet_dataset
from pruna.data.pruna_datamodule import PrunaDataModule

@@ -45,6 +45,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
pytest.param("GenAIBench", dict(), marks=pytest.mark.slow),
pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow),
pytest.param("VBench", dict(), marks=pytest.mark.slow),
pytest.param("GenEval", dict(), marks=pytest.mark.slow),
],
)
def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None:
@@ -80,3 +81,35 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar
assert labels.dtype == torch.int64
# iterate through the dataloaders
iterate_dataloaders(datamodule)



@pytest.mark.slow
def test_parti_prompts_with_category_filter():
"""Test PartiPrompts loading with category filter."""
dm = PrunaDataModule.from_string(
"PartiPrompts", category="Animals", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)
batch = next(iter(dm.test_dataloader()))
prompts, auxiliaries = batch

assert len(prompts) == 4
assert all(isinstance(p, str) for p in prompts)
assert all(aux["Category"] == "Animals" for aux in auxiliaries)


@pytest.mark.slow
def test_geneval_with_category_filter():
"""Test GenEval loading with category filter."""
dm = PrunaDataModule.from_string(
"GenEval", category="counting", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)
batch = next(iter(dm.test_dataloader()))
prompts, auxiliaries = batch

assert len(prompts) == 4
assert all(isinstance(p, str) for p in prompts)
assert all(aux["tag"] == "counting" for aux in auxiliaries)
assert all("questions" in aux for aux in auxiliaries)