Binary file added [conflicted 2].coverage
Binary file added [conflicted 3].coverage
Binary file added [conflicted 4].coverage
Binary file added [conflicted 5].coverage
Binary file added [conflicted].coverage
191 changes: 190 additions & 1 deletion src/pruna/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Tuple

@@ -26,8 +27,11 @@
setup_mnist_dataset,
)
from pruna.data.datasets.prompt import (
setup_dpg_dataset,
setup_drawbench_dataset,
setup_genai_bench_dataset,
setup_oneig_alignment_dataset,
setup_oneig_text_rendering_dataset,
setup_parti_prompts_dataset,
)
from pruna.data.datasets.question_answering import setup_polyglot_dataset
@@ -97,8 +101,193 @@
{"img_size": 224},
),
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
"OneIGTextRendering": (setup_oneig_text_rendering_dataset, "prompt_with_auxiliaries_collate", {}),
"OneIGAlignment": (setup_oneig_alignment_dataset, "prompt_with_auxiliaries_collate", {}),
"DPG": (setup_dpg_dataset, "prompt_with_auxiliaries_collate", {}),
"TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
"VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
}


@dataclass
class BenchmarkInfo:
"""
Metadata for a benchmark dataset.

Parameters
----------
name : str
Internal identifier for the benchmark.
display_name : str
Human-readable name for display purposes.
description : str
Description of what the benchmark evaluates.
metrics : list[str]
List of metric names used for evaluation.
task_type : str
Type of task the benchmark evaluates (e.g., 'text_to_image').
subsets : list[str]
Optional list of benchmark subset names.
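
Examples
--------
A minimal, illustrative instance (placeholder values); ``subsets``
falls back to an empty list via its default factory:

>>> info = BenchmarkInfo(
...     name="example",
...     display_name="Example",
...     description="Placeholder entry for illustration.",
...     metrics=["clip_score"],
...     task_type="text_to_image",
... )
>>> info.subsets
[]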
"""

name: str
display_name: str
description: str
metrics: list[str]
task_type: str
subsets: list[str] = field(default_factory=list)


benchmark_info: dict[str, BenchmarkInfo] = {
"PartiPrompts": BenchmarkInfo(
name="parti_prompts",
display_name="Parti Prompts",
description=(
"Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
"ranging from basic to complex, enabling comprehensive assessment of model capabilities "
"across different domains and difficulty levels."
),
metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
subsets=[
"Abstract",
"Animals",
"Artifacts",
"Arts",
"Food & Beverage",
"Illustrations",
"Indoor Scenes",
"Outdoor Scenes",
"People",
"Produce & Plants",
"Vehicles",
"World Knowledge",
"Basic",
"Complex",
"Fine-grained Detail",
"Imagination",
"Linguistic Structures",
"Perspective",
"Properties & Positioning",
"Quantity",
"Simple Detail",
"Style & Format",
"Writing & Symbols",
],
),
"DrawBench": BenchmarkInfo(
name="drawbench",
display_name="DrawBench",
description="A comprehensive benchmark for evaluating text-to-image generation models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"GenAIBench": BenchmarkInfo(
name="genai_bench",
display_name="GenAI Bench",
description="A benchmark for evaluating generative AI models.",
metrics=["clip_score", "clipiqa", "sharpness"],
task_type="text_to_image",
),
"VBench": BenchmarkInfo(
name="vbench",
display_name="VBench",
description="A benchmark for evaluating video generation models.",
metrics=["clip_score"],
task_type="text_to_video",
),
"COCO": BenchmarkInfo(
name="coco",
display_name="COCO",
description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
metrics=["fid", "clip_score", "clipiqa"],
task_type="text_to_image",
),
"ImageNet": BenchmarkInfo(
name="imagenet",
display_name="ImageNet",
description="Large-scale image classification benchmark with 1000 classes.",
metrics=["accuracy"],
task_type="image_classification",
),
"WikiText": BenchmarkInfo(
name="wikitext",
display_name="WikiText",
description="Language modeling benchmark based on Wikipedia articles.",
metrics=["perplexity"],
task_type="text_generation",
),
"OneIGTextRendering": BenchmarkInfo(
name="oneig_text_rendering",
display_name="OneIG Text Rendering",
description="Evaluates text rendering quality in generated images using OCR-based metrics.",
metrics=["accuracy"],
task_type="text_to_image",
),
"OneIGAlignment": BenchmarkInfo(
name="oneig_alignment",
display_name="OneIG Alignment",
description="Evaluates image-text alignment for anime, human, and object generation with VQA-based questions.",
metrics=["accuracy"],
task_type="text_to_image",
subsets=["Anime_Stylization", "Portrait", "General_Object"],
),
"DPG": BenchmarkInfo(
name="dpg",
display_name="DPG",
description=(
"Descriptive Prompt Generation benchmark for evaluating image understanding "
"across entity, attribute, relation, and global aspects."
),
metrics=["accuracy"],
task_type="text_to_image",
subsets=["entity", "attribute", "relation", "global", "other"],
),
}


def list_benchmarks(task_type: str | None = None) -> list[str]:
"""
List available benchmark names.

Parameters
----------
task_type : str | None
Filter by task type (e.g., 'text_to_image', 'text_to_video').
If None, returns all benchmarks.

Returns
-------
list[str]
List of benchmark names.
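
Examples
--------
Results reflect the ``benchmark_info`` registry defined in this module:

>>> list_benchmarks(task_type="text_to_video")
['VBench']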
"""
if task_type is None:
return list(benchmark_info.keys())
return [name for name, info in benchmark_info.items() if info.task_type == task_type]


def get_benchmark_info(name: str) -> BenchmarkInfo:
"""
Get benchmark metadata by name.

Parameters
----------
name : str
The benchmark name.

Returns
-------
BenchmarkInfo
The benchmark metadata.

Raises
------
KeyError
If the benchmark name is not found.
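
Examples
--------
Lookups resolve against the ``benchmark_info`` registry defined in
this module; unknown names raise ``KeyError``:

>>> get_benchmark_info("WikiText").metrics
['perplexity']
>>> get_benchmark_info("WikiText").task_type
'text_generation'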
"""
if name not in benchmark_info:
available = ", ".join(benchmark_info.keys())
raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
return benchmark_info[name]