172 changes: 171 additions & 1 deletion src/pruna/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Tuple

@@ -28,6 +29,7 @@
from pruna.data.datasets.prompt import (
setup_drawbench_dataset,
setup_genai_bench_dataset,
setup_long_text_bench_dataset,
setup_parti_prompts_dataset,
)
from pruna.data.datasets.question_answering import setup_polyglot_dataset
@@ -97,8 +99,176 @@
{"img_size": 224},
),
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
"LongTextBench": (setup_long_text_bench_dataset, "prompt_with_auxiliaries_collate", {}),
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
}


@dataclass
class BenchmarkInfo:
"""
Metadata for a benchmark dataset.

Parameters
----------
name : str
Internal identifier for the benchmark.
display_name : str
Human-readable name for display purposes.
description : str
Description of what the benchmark evaluates.
metrics : list[str]
List of metric names used for evaluation.
task_type : str
Type of task the benchmark evaluates (e.g., 'text_to_image').
subsets : list[str]
Optional list of benchmark subset names.
"""

name: str
display_name: str
description: str
metrics: list[str]
task_type: str
subsets: list[str] = field(default_factory=list)


benchmark_info: dict[str, BenchmarkInfo] = {
"PartiPrompts": BenchmarkInfo(
name="parti_prompts",
display_name="Parti Prompts",
description=(
"Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
"ranging from basic to complex, enabling comprehensive assessment of model capabilities "
"across different domains and difficulty levels."
),
metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
subsets=[
"Abstract",
"Animals",
"Artifacts",
"Arts",
"Food & Beverage",
"Illustrations",
"Indoor Scenes",
"Outdoor Scenes",
"People",
"Produce & Plants",
"Vehicles",
"World Knowledge",
"Basic",
"Complex",
"Fine-grained Detail",
"Imagination",
"Linguistic Structures",
"Perspective",
"Properties & Positioning",
"Quantity",
"Simple Detail",
"Style & Format",
"Writing & Symbols",
],
),
"DrawBench": BenchmarkInfo(
name="drawbench",
display_name="DrawBench",
description="A comprehensive benchmark for evaluating text-to-image generation models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"GenAIBench": BenchmarkInfo(
name="genai_bench",
display_name="GenAI Bench",
description="A benchmark for evaluating generative AI models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"VBench": BenchmarkInfo(
name="vbench",
display_name="VBench",
description="A benchmark for evaluating video generation models.",
metrics=["clip_score"],
task_type="text_to_video",
),
"LongTextBench": BenchmarkInfo(
name="long_text_bench",
display_name="Long Text Bench",
description=(
"Extended detail-rich prompts averaging 284.89 tokens with evaluation dimensions of "
"character attributes, structured locations, scene attributes, and spatial relationships "
"to test compositional reasoning under long prompt complexity."
),
metrics=["clip_score", "clipiqa"],
task_type="text_to_image",
),
"COCO": BenchmarkInfo(
name="coco",
display_name="COCO",
description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
metrics=["fid", "clip_score", "clipiqa"],
task_type="text_to_image",
),
"ImageNet": BenchmarkInfo(
name="imagenet",
display_name="ImageNet",
description="Large-scale image classification benchmark with 1000 classes.",
metrics=["accuracy"],
task_type="image_classification",
),
"WikiText": BenchmarkInfo(
name="wikitext",
display_name="WikiText",
description="Language modeling benchmark based on Wikipedia articles.",
metrics=["perplexity"],
task_type="text_generation",
),
}


def list_benchmarks(task_type: str | None = None) -> list[str]:
"""
List available benchmark names.

Parameters
----------
task_type : str | None
Filter by task type (e.g., 'text_to_image', 'text_to_video').
If None, returns all benchmarks.

Returns
-------
list[str]
List of benchmark names.
"""
if task_type is None:
return list(benchmark_info.keys())
return [name for name, info in benchmark_info.items() if info.task_type == task_type]


def get_benchmark_info(name: str) -> BenchmarkInfo:
"""
Get benchmark metadata by name.

Parameters
----------
name : str
The benchmark name.

Returns
-------
BenchmarkInfo
The benchmark metadata.

Raises
------
KeyError
If benchmark name is not found.
"""
if name not in benchmark_info:
available = ", ".join(benchmark_info.keys())
raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
return benchmark_info[name]
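
A minimal usage sketch of the new metadata helpers (not part of the diff; the import path is assumed from this file's location, src/pruna/data/__init__.py):

from pruna.data import get_benchmark_info, list_benchmarks

# Enumerate only the text-to-image benchmarks registered above.
for name in list_benchmarks(task_type="text_to_image"):
    info = get_benchmark_info(name)
    print(f"{info.display_name}: metrics={info.metrics}, {len(info.subsets)} subsets")

# Unknown names raise a KeyError that lists the available benchmarks.
try:
    get_benchmark_info("NotABenchmark")
except KeyError as err:
    print(err)
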
58 changes: 56 additions & 2 deletions src/pruna/data/datasets/prompt.py
@@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
return ds.select([0]), ds.select([0]), ds


def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
def setup_parti_prompts_dataset(
seed: int,
category: str | None = None,
num_samples: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the Parti Prompts dataset.

@@ -51,18 +55,68 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
----------
seed : int
The seed to use.
category : str | None
Filter by Category or Challenge. Available categories: Abstract, Animals, Artifacts,
Arts, Food & Beverage, Illustrations, Indoor Scenes, Outdoor Scenes, People,
Produce & Plants, Vehicles, World Knowledge. Available challenges: Basic, Complex,
Fine-grained Detail, Imagination, Linguistic Structures, Perspective,
Properties & Positioning, Quantity, Simple Detail, Style & Format, Writing & Symbols.
num_samples : int | None
Maximum number of samples to return. If None, returns all samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The Parti Prompts dataset.
The Parti Prompts dataset (dummy train, dummy val, test).
"""
ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index]

if category is not None:
ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)

ds = ds.shuffle(seed=seed)

if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

ds = ds.rename_column("Prompt", "text")
pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds
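
As a hedged sketch (not part of the diff), the updated signature can be exercised directly; the filter matches either the Category or the Challenge column before shuffling and optional truncation:

from pruna.data.datasets.prompt import setup_parti_prompts_dataset

# The third split is the full filtered test set; the first two are single-row dummies.
_, _, test_ds = setup_parti_prompts_dataset(seed=42, category="Animals", num_samples=16)
assert len(test_ds) <= 16
assert all(row["Category"] == "Animals" or row["Challenge"] == "Animals" for row in test_ds)
assert "text" in test_ds.column_names  # "Prompt" is renamed to "text"
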


def setup_long_text_bench_dataset(
seed: int,
num_samples: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the Long Text Bench dataset.

License: Apache 2.0

Parameters
----------
seed : int
The seed to use.
num_samples : int | None
Maximum number of samples to return. If None, returns all 160 samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The Long Text Bench dataset (dummy train, dummy val, test).
"""
ds = load_dataset("X-Omni/LongText-Bench")["train"] # type: ignore[index]
ds = ds.rename_column("text", "text_content")
ds = ds.rename_column("prompt", "text")
ds = ds.shuffle(seed=seed)

if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

pruna_logger.info("LongTextBench is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds
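
A similar sketch for the new loader (again not part of the diff); note the column swap, which leaves the generation prompt in "text" and the rendered text content in "text_content":

from pruna.data.datasets.prompt import setup_long_text_bench_dataset

_, _, test_ds = setup_long_text_bench_dataset(seed=0, num_samples=8)
assert len(test_ds) <= 8
assert {"text", "text_content"}.issubset(test_ds.column_names)
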


def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the GenAI Bench dataset.
38 changes: 35 additions & 3 deletions tests/data/test_datamodule.py
@@ -1,10 +1,10 @@
from typing import Any, Callable

import pytest
from transformers import AutoTokenizer
from datasets import Dataset
from torch.utils.data import TensorDataset
import torch
from transformers import AutoTokenizer

from pruna.data import BenchmarkInfo, benchmark_info
from pruna.data.datasets.image import setup_imagenet_dataset
from pruna.data.pruna_datamodule import PrunaDataModule

@@ -45,6 +45,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
pytest.param("GenAIBench", dict(), marks=pytest.mark.slow),
pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow),
pytest.param("VBench", dict(), marks=pytest.mark.slow),
pytest.param("LongTextBench", dict(), marks=pytest.mark.slow),
],
)
def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None:
@@ -80,3 +81,34 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar
assert labels.dtype == torch.int64
# iterate through the dataloaders
iterate_dataloaders(datamodule)


@pytest.mark.slow
def test_parti_prompts_with_category_filter() -> None:
"""Test PartiPrompts loading with category filter."""
dm = PrunaDataModule.from_string(
"PartiPrompts", category="Animals", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)
batch = next(iter(dm.test_dataloader()))
prompts, auxiliaries = batch

assert len(prompts) == 4
assert all(isinstance(p, str) for p in prompts)
assert all(aux["Category"] == "Animals" for aux in auxiliaries)


@pytest.mark.slow
def test_long_text_bench_auxiliaries() -> None:
"""Test LongTextBench loading with auxiliaries."""
dm = PrunaDataModule.from_string(
"LongTextBench", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)
batch = next(iter(dm.test_dataloader()))
prompts, auxiliaries = batch

assert len(prompts) == 4
assert all(isinstance(p, str) for p in prompts)
assert all("text_content" in aux for aux in auxiliaries)
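
Both new tests are marked slow and download their datasets from the Hugging Face Hub. One plausible way to run just these two from Python (the selection flags are standard pytest; the -k expression is an assumption matching the test names added here):

import pytest

# Select the slow-marked tests whose names match the new additions.
pytest.main(["-m", "slow", "-k", "category_filter or long_text_bench", "tests/data/test_datamodule.py"])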