Binary file added [conflicted 2].coverage
Binary file added [conflicted 3].coverage
Binary file added [conflicted 4].coverage
Binary file added [conflicted 5].coverage
Binary file added [conflicted 6].coverage
Binary file added [conflicted].coverage
172 changes: 171 additions & 1 deletion src/pruna/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Tuple

@@ -28,6 +29,7 @@
from pruna.data.datasets.prompt import (
setup_drawbench_dataset,
setup_genai_bench_dataset,
setup_oneig_dataset,
setup_parti_prompts_dataset,
)
from pruna.data.datasets.question_answering import setup_polyglot_dataset
@@ -97,8 +99,176 @@
{"img_size": 224},
),
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
"OneIG": (setup_oneig_dataset, "prompt_with_auxiliaries_collate", {}),
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
}


@dataclass
class BenchmarkInfo:
"""
Metadata for a benchmark dataset.

Parameters
----------
name : str
Internal identifier for the benchmark.
display_name : str
Human-readable name for display purposes.
description : str
Description of what the benchmark evaluates.
metrics : list[str]
List of metric names used for evaluation.
task_type : str
Type of task the benchmark evaluates (e.g., 'text_to_image').
subsets : list[str]
Optional list of benchmark subset names.
"""

name: str
display_name: str
description: str
metrics: list[str]
task_type: str
subsets: list[str] = field(default_factory=list)


benchmark_info: dict[str, BenchmarkInfo] = {
"PartiPrompts": BenchmarkInfo(
name="parti_prompts",
display_name="Parti Prompts",
description=(
"Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
"ranging from basic to complex, enabling comprehensive assessment of model capabilities "
"across different domains and difficulty levels."
),
metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
subsets=[
"Abstract",
"Animals",
"Artifacts",
"Arts",
"Food & Beverage",
"Illustrations",
"Indoor Scenes",
"Outdoor Scenes",
"People",
"Produce & Plants",
"Vehicles",
"World Knowledge",
"Basic",
"Complex",
"Fine-grained Detail",
"Imagination",
"Linguistic Structures",
"Perspective",
"Properties & Positioning",
"Quantity",
"Simple Detail",
"Style & Format",
"Writing & Symbols",
],
),
"DrawBench": BenchmarkInfo(
name="drawbench",
display_name="DrawBench",
description="A comprehensive benchmark for evaluating text-to-image generation models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"GenAIBench": BenchmarkInfo(
name="genai_bench",
display_name="GenAI Bench",
description="A benchmark for evaluating generative AI models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"VBench": BenchmarkInfo(
name="vbench",
display_name="VBench",
description="A benchmark for evaluating video generation models.",
metrics=["clip_score"],
task_type="text_to_video",
),
"COCO": BenchmarkInfo(
name="coco",
display_name="COCO",
description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
metrics=["fid", "clip_score", "clipiqa"],
task_type="text_to_image",
),
"ImageNet": BenchmarkInfo(
name="imagenet",
display_name="ImageNet",
description="Large-scale image classification benchmark with 1000 classes.",
metrics=["accuracy"],
task_type="image_classification",
),
"WikiText": BenchmarkInfo(
name="wikitext",
display_name="WikiText",
description="Language modeling benchmark based on Wikipedia articles.",
metrics=["perplexity"],
task_type="text_generation",
),
"OneIG": BenchmarkInfo(
name="oneig",
display_name="OneIG",
description=(
"Comprehensive benchmark for text rendering and image-text alignment "
"evaluation across anime, portrait, and object generation."
),
metrics=["accuracy"],
task_type="text_to_image",
subsets=["text_rendering", "anime_alignment", "portrait_alignment", "object_alignment"],
),
}


def list_benchmarks(task_type: str | None = None) -> list[str]:
"""
List available benchmark names.

Parameters
----------
task_type : str | None
Filter by task type (e.g., 'text_to_image', 'text_to_video').
If None, returns all benchmarks.

Returns
-------
list[str]
List of benchmark names.
"""
if task_type is None:
return list(benchmark_info.keys())
return [name for name, info in benchmark_info.items() if info.task_type == task_type]


def get_benchmark_info(name: str) -> BenchmarkInfo:
"""
Get benchmark metadata by name.

Parameters
----------
name : str
The benchmark name.

Returns
-------
BenchmarkInfo
The benchmark metadata.

Raises
------
KeyError
If benchmark name is not found.
"""
if name not in benchmark_info:
available = ", ".join(benchmark_info.keys())
raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
return benchmark_info[name]
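
For quick reference, a minimal usage sketch of the registry helpers added above (assuming the installed package exposes them from pruna.data as defined in this __init__.py):

from pruna.data import get_benchmark_info, list_benchmarks

# All registered benchmarks, then only the text-to-image ones.
print(list_benchmarks())
print(list_benchmarks(task_type="text_to_image"))

# Metadata lookup; unknown names raise KeyError listing the available options.
info = get_benchmark_info("OneIG")
print(info.display_name, info.metrics, info.subsets)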
152 changes: 150 additions & 2 deletions src/pruna/data/datasets/prompt.py
@@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
return ds.select([0]), ds.select([0]), ds


def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
def setup_parti_prompts_dataset(
seed: int,
category: str | None = None,
num_samples: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the Parti Prompts dataset.

@@ -51,13 +55,30 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
----------
seed : int
The seed to use.
category : str | None
Filter by Category or Challenge. Available categories: Abstract, Animals, Artifacts,
Arts, Food & Beverage, Illustrations, Indoor Scenes, Outdoor Scenes, People,
Produce & Plants, Vehicles, World Knowledge. Available challenges: Basic, Complex,
Fine-grained Detail, Imagination, Linguistic Structures, Perspective,
Properties & Positioning, Quantity, Simple Detail, Style & Format, Writing & Symbols.
num_samples : int | None
Maximum number of samples to return. If None, returns all samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The Parti Prompts dataset.
The Parti Prompts dataset (dummy train, dummy val, test).
"""
ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index]

if category is not None:
ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)

ds = ds.shuffle(seed=seed)

if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

ds = ds.rename_column("Prompt", "text")
pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds
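
A short sketch of the new filtering parameters on the Parti Prompts setup (the category value is illustrative):

from pruna.data.datasets.prompt import setup_parti_prompts_dataset

# Keep only the "Animals" category and cap the result at 50 prompts;
# train and val splits are single-row dummies for this test-only dataset.
_, _, test_ds = setup_parti_prompts_dataset(seed=42, category="Animals", num_samples=50)
print(len(test_ds), test_ds[0]["text"])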
@@ -83,3 +104,130 @@ def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
ds = ds.rename_column("Prompt", "text")
pruna_logger.info("GenAI-Bench is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds


ONEIG_SUBSETS = ["text_rendering", "anime_alignment", "portrait_alignment", "object_alignment"]


def _load_oneig_text_rendering(seed: int) -> Dataset:
"""Load OneIG text rendering data from GitHub CSV."""
import csv
import io

import requests

url = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/main/benchmark/text_rendering.csv"
response = requests.get(url)
reader = csv.DictReader(io.StringIO(response.text))

records = []
for row in reader:
records.append(
{
"text": row.get("prompt", ""),
"subset": "text_rendering",
"text_content": row.get("text_content", row.get("text", "")),
"category": None,
"questions": [],
"dependencies": [],
}
)

return Dataset.from_list(records).shuffle(seed=seed)


def _load_oneig_alignment(seed: int, category: str | None = None) -> Dataset:
"""Load OneIG alignment data from HuggingFace + GitHub JSON."""
import json

import requests

category_map = {
"anime_alignment": "Anime_Stylization",
"portrait_alignment": "Portrait",
"object_alignment": "General_Object",
}

ds = load_dataset("OneIG-Bench/OneIG-Bench")["test"] # type: ignore[index]

url = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/main/benchmark/alignment_questions.json"
response = requests.get(url)
questions_data = json.loads(response.text)
questions_by_id = {q["id"]: q for q in questions_data}

records = []
for row in ds:
row_id = row.get("id", "")
row_category = row.get("category", "")

if category is not None:
target_category = category_map.get(category, category)
if row_category != target_category:
continue

subset_name = {v: k for k, v in category_map.items()}.get(row_category, "alignment")
q_info = questions_by_id.get(row_id, {})
records.append(
{
"text": row.get("prompt", ""),
"subset": subset_name,
"text_content": None,
"category": row_category,
"questions": q_info.get("questions", []),
"dependencies": q_info.get("dependencies", []),
}
)

return Dataset.from_list(records).shuffle(seed=seed)


def setup_oneig_dataset(
seed: int,
subset: str | None = None,
num_samples: int | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the OneIG benchmark dataset.

License: Apache 2.0

Parameters
----------
seed : int
The seed to use.
subset : str | None
Filter by subset. Available: text_rendering, anime_alignment, portrait_alignment,
object_alignment. If None, returns all subsets.
num_samples : int | None
Maximum number of samples to return. If None, returns all samples.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The OneIG dataset (dummy train, dummy val, test).
"""
from datasets import concatenate_datasets

if subset is not None and subset not in ONEIG_SUBSETS:
raise ValueError(f"Invalid subset: {subset}. Must be one of {ONEIG_SUBSETS}")

datasets_to_concat = []

if subset is None or subset == "text_rendering":
datasets_to_concat.append(_load_oneig_text_rendering(seed))

if subset is None or subset in ["anime_alignment", "portrait_alignment", "object_alignment"]:
alignment_subset = subset if subset in ["anime_alignment", "portrait_alignment", "object_alignment"] else None
datasets_to_concat.append(_load_oneig_alignment(seed, alignment_subset))

ds = concatenate_datasets(datasets_to_concat) if len(datasets_to_concat) > 1 else datasets_to_concat[0]
ds = ds.shuffle(seed=seed)

if num_samples is not None:
ds = ds.select(range(min(num_samples, len(ds))))

if len(ds) == 0:
raise ValueError(f"No samples found for subset '{subset}'. Check that the subset exists and has data.")

pruna_logger.info("OneIG is a test-only dataset. Do not use it for training or validation.")
return ds.select([0]), ds.select([0]), ds
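
A minimal sketch of calling the new OneIG setup directly (subset and num_samples are optional; network access to the OneIG CSV/JSON sources is required):

from pruna.data.datasets.prompt import setup_oneig_dataset

# Only the text-rendering subset, capped at 100 prompts.
_, _, test_ds = setup_oneig_dataset(seed=42, subset="text_rendering", num_samples=100)
print(test_ds[0]["text"], test_ds[0]["subset"])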
7 changes: 6 additions & 1 deletion src/pruna/data/pruna_datamodule.py
@@ -136,6 +136,7 @@ def from_string(
dataloader_args: dict = dict(),
seed: int = 42,
category: str | list[str] | None = None,
subset: str | None = None,
) -> "PrunaDataModule":
"""
Create a PrunaDataModule from the dataset name with preimplemented dataset loading.
@@ -152,9 +153,10 @@
Any additional arguments for the dataloader.
seed : int
The seed to use.

category : str | list[str] | None
The category (or challenge) to filter by; forwarded only to setup functions that accept a category parameter.
subset : str | None
The subset to filter by (e.g. a OneIG subset); forwarded only to setup functions that accept a subset parameter.

Returns
-------
@@ -173,6 +175,9 @@
if "category" in inspect.signature(setup_fn).parameters:
setup_fn = partial(setup_fn, category=category)

if "subset" in inspect.signature(setup_fn).parameters:
setup_fn = partial(setup_fn, subset=subset)

train_ds, val_ds, test_ds = setup_fn()

return cls.from_datasets(
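
A hedged sketch of threading the new subset argument through PrunaDataModule.from_string; the leading parameters of from_string are truncated in this diff, so the dataset-name argument below is an assumption based on existing usage:

from pruna.data.pruna_datamodule import PrunaDataModule

# Assumed call shape: dataset name first, with the new `subset` keyword
# forwarded to setup functions that accept it (here a OneIG subset).
dm = PrunaDataModule.from_string("OneIG", subset="anime_alignment")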