159 changes: 158 additions & 1 deletion src/pruna/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Tuple

@@ -97,8 +98,164 @@
{"img_size": 224},
),
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
}


@dataclass
class BenchmarkInfo:
"""
Metadata for a benchmark dataset.

Parameters
----------
name : str
Internal identifier for the benchmark.
display_name : str
Human-readable name for display purposes.
description : str
Description of what the benchmark evaluates.
metrics : list[str]
List of metric names used for evaluation.
task_type : str
Type of task the benchmark evaluates (e.g., 'text_to_image').
subsets : list[str]
Optional list of benchmark subset names.
"""

name: str
display_name: str
description: str
metrics: list[str]
task_type: str
subsets: list[str] = field(default_factory=list)


benchmark_info: dict[str, BenchmarkInfo] = {
"PartiPrompts": BenchmarkInfo(
name="parti_prompts",
display_name="Parti Prompts",
description=(
"Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
"ranging from basic to complex, enabling comprehensive assessment of model capabilities "
"across different domains and difficulty levels."
),
metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
Member comment: I don't think we should tightly couple benchmarking datasets with metrics. I think benchmarks should have their datasets available as PrunaDataModules, and the metrics for the benchmarks should be Pruna Metrics. This way we can give the user the flexibility to use whichever dataset with whichever metric they choose. How do you feel?
(A sketch of this decoupling follows after the benchmark_info registry below.)

task_type="text_to_image",
subsets=[
"Abstract",
"Animals",
"Artifacts",
"Arts",
"Food & Beverage",
"Illustrations",
"Indoor Scenes",
"Outdoor Scenes",
"People",
"Produce & Plants",
"Vehicles",
"World Knowledge",
"Basic",
"Complex",
"Fine-grained Detail",
"Imagination",
"Linguistic Structures",
"Perspective",
"Properties & Positioning",
"Quantity",
"Simple Detail",
"Style & Format",
"Writing & Symbols",
],
),
"DrawBench": BenchmarkInfo(
name="drawbench",
display_name="DrawBench",
description="A comprehensive benchmark for evaluating text-to-image generation models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"GenAIBench": BenchmarkInfo(
name="genai_bench",
display_name="GenAI Bench",
description="A benchmark for evaluating generative AI models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"VBench": BenchmarkInfo(
name="vbench",
display_name="VBench",
description="A benchmark for evaluating video generation models.",
metrics=["clip_score"],
task_type="text_to_video",
),
"COCO": BenchmarkInfo(
name="coco",
display_name="COCO",
description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
metrics=["fid", "clip_score", "clipiqa"],
task_type="text_to_image",
),
"ImageNet": BenchmarkInfo(
name="imagenet",
display_name="ImageNet",
description="Large-scale image classification benchmark with 1000 classes.",
metrics=["accuracy"],
task_type="image_classification",
),
"WikiText": BenchmarkInfo(
name="wikitext",
display_name="WikiText",
description="Language modeling benchmark based on Wikipedia articles.",
metrics=["perplexity"],
task_type="text_generation",
),
}
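
To make the decoupling suggested in the comment above concrete, here is a minimal sketch (not part of this PR) of how benchmark datasets and metrics could stay independent: the registry's metrics field acts only as a suggested default, and the final evaluation call is a hypothetical placeholder. PrunaDataModule.from_string, get_benchmark_info, and list_benchmarks are the APIs added or used in this diff; everything else is assumed.

from pruna.data import get_benchmark_info, list_benchmarks
from pruna.data.pruna_datamodule import PrunaDataModule

# Pick a benchmark purely as a data module ...
datamodule = PrunaDataModule.from_string("PartiPrompts", dataloader_args={"batch_size": 4})

# ... and choose metrics independently; the registry only suggests defaults.
suggested_metrics = get_benchmark_info("PartiPrompts").metrics
chosen_metrics = ["clip_score"]  # user override, decoupled from the dataset

# Browse what is registered for a given task type.
for name in list_benchmarks(task_type="text_to_image"):
    print(name, get_benchmark_info(name).metrics)

# results = evaluate(model, datamodule, chosen_metrics)  # hypothetical entry point, shown for shape only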


def list_benchmarks(task_type: str | None = None) -> list[str]:
"""
List available benchmark names.

Parameters
----------
task_type : str | None
Filter by task type (e.g., 'text_to_image', 'text_to_video').
If None, returns all benchmarks.

Returns
-------
list[str]
List of benchmark names.
"""
if task_type is None:
return list(benchmark_info.keys())
return [name for name, info in benchmark_info.items() if info.task_type == task_type]


def get_benchmark_info(name: str) -> BenchmarkInfo:
"""
Get benchmark metadata by name.

Parameters
----------
name : str
The benchmark name.

Returns
-------
BenchmarkInfo
The benchmark metadata.

Raises
------
KeyError
If benchmark name is not found.
"""
if name not in benchmark_info:
available = ", ".join(benchmark_info.keys())
raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
return benchmark_info[name]
Review comment (Medium severity): Unused functions defined but never called. list_benchmarks() and get_benchmark_info() are defined but never called anywhere in the codebase. The PR description mentions a from_benchmark method that would presumably use these, but it is not implemented in this PR.
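
Since from_benchmark is only mentioned in the PR description, the following is a hypothetical sketch of how such a constructor could consume the registry helpers; the function name and its return shape are assumptions, while get_benchmark_info and PrunaDataModule.from_string are the APIs added or used in this PR.

from pruna.data import get_benchmark_info
from pruna.data.pruna_datamodule import PrunaDataModule


def datamodule_from_benchmark(benchmark_name: str, batch_size: int = 4):
    """Illustrative only: build a data module for a registered benchmark."""
    info = get_benchmark_info(benchmark_name)  # raises KeyError listing available names
    dm = PrunaDataModule.from_string(benchmark_name, dataloader_args={"batch_size": batch_size})
    # Suggested metrics travel alongside, not inside, the data module.
    return dm, info.metrics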

25 changes: 23 additions & 2 deletions src/pruna/data/datasets/prompt.py
@@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
return ds.select([0]), ds.select([0]), ds


- def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
+ def setup_parti_prompts_dataset(
+     seed: int,
+     category: str | None = None,
+     num_samples: int | None = None,
+ ) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the Parti Prompts dataset.

@@ -51,13 +55,30 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
----------
seed : int
The seed to use.
category : str | None
Filter by Category or Challenge. Available categories: Abstract, Animals, Artifacts,
Arts, Food & Beverage, Illustrations, Indoor Scenes, Outdoor Scenes, People,
Produce & Plants, Vehicles, World Knowledge. Available challenges: Basic, Complex,
Fine-grained Detail, Imagination, Linguistic Structures, Perspective,
Properties & Positioning, Quantity, Simple Detail, Style & Format, Writing & Symbols.
num_samples : int | None
Maximum number of samples to return. If None, returns all samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
- The Parti Prompts dataset.
+ The Parti Prompts dataset (dummy train, dummy val, test).
"""
ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index]

if category is not None:
ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)
Review comment (Medium severity): Category filter fails silently when a list is passed. The setup_parti_prompts_dataset function's category parameter only accepts str | None, but PrunaDataModule.from_string accepts category: str | list[str] | None. When a list is passed, the filter x["Category"] == category or x["Challenge"] == category compares strings against a list, which always evaluates to False. This silently filters out all records, resulting in an empty dataset that causes ds.select([0]) to fail with an index error.
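
One way to address this, sketched as a self-contained helper rather than a patch to this function, is to widen the parameter to the str | list[str] | None type that PrunaDataModule.from_string already accepts; this is a suggestion, not code from the PR.

from datasets import Dataset


def filter_by_category(ds: Dataset, category: str | list[str] | None) -> Dataset:
    """Keep rows whose Category or Challenge matches any requested value."""
    if category is None:
        return ds
    wanted = [category] if isinstance(category, str) else list(category)
    return ds.filter(lambda x: x["Category"] in wanted or x["Challenge"] in wanted)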


ds = ds.shuffle(seed=seed)
Member comment: Since we are only creating a test set, how do you feel about not shuffling the data?


if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

ds = ds.rename_column("Prompt", "text")
pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds
22 changes: 19 additions & 3 deletions tests/data/test_datamodule.py
@@ -1,10 +1,10 @@
from typing import Any, Callable

import pytest
from transformers import AutoTokenizer
from datasets import Dataset
from torch.utils.data import TensorDataset
import torch
from transformers import AutoTokenizer

from pruna.data import BenchmarkInfo, benchmark_info
Review comment (Low severity): Unused imports in test file. BenchmarkInfo and benchmark_info are imported but never used in the test file; these imports should be removed.

from pruna.data.datasets.image import setup_imagenet_dataset
from pruna.data.pruna_datamodule import PrunaDataModule

@@ -80,3 +80,19 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar
assert labels.dtype == torch.int64
# iterate through the dataloaders
iterate_dataloaders(datamodule)



@pytest.mark.slow
def test_parti_prompts_with_category_filter():
"""Test PartiPrompts loading with category filter."""
dm = PrunaDataModule.from_string(
"PartiPrompts", category="Animals", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)
batch = next(iter(dm.test_dataloader()))
prompts, auxiliaries = batch

assert len(prompts) == 4
assert all(isinstance(p, str) for p in prompts)
assert all(aux["Category"] == "Animals" for aux in auxiliaries)