feat: add benchmark support to PrunaDataModule and implement PartiPrompts #502
base: main
Changes from all commits
6db8f0b
7c53c95
975adb3
6b0f4f7
56f2167
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from dataclasses import dataclass, field
 from functools import partial
 from typing import Any, Callable, Tuple
```
```diff
@@ -97,8 +98,164 @@
         {"img_size": 224},
     ),
     "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
-    "PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
+    "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
     "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
     "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
     "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
 }
+
+
+@dataclass
+class BenchmarkInfo:
+    """
+    Metadata for a benchmark dataset.
+
+    Parameters
+    ----------
+    name : str
+        Internal identifier for the benchmark.
+    display_name : str
+        Human-readable name for display purposes.
+    description : str
+        Description of what the benchmark evaluates.
+    metrics : list[str]
+        List of metric names used for evaluation.
+    task_type : str
+        Type of task the benchmark evaluates (e.g., 'text_to_image').
+    subsets : list[str]
+        Optional list of benchmark subset names.
+    """
+
+    name: str
+    display_name: str
+    description: str
+    metrics: list[str]
+    task_type: str
+    subsets: list[str] = field(default_factory=list)
+
+
+benchmark_info: dict[str, BenchmarkInfo] = {
+    "PartiPrompts": BenchmarkInfo(
+        name="parti_prompts",
+        display_name="Parti Prompts",
+        description=(
+            "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
+            "ranging from basic to complex, enabling comprehensive assessment of model capabilities "
+            "across different domains and difficulty levels."
+        ),
+        metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+        subsets=[
+            "Abstract",
+            "Animals",
+            "Artifacts",
+            "Arts",
+            "Food & Beverage",
+            "Illustrations",
+            "Indoor Scenes",
+            "Outdoor Scenes",
+            "People",
+            "Produce & Plants",
+            "Vehicles",
+            "World Knowledge",
+            "Basic",
+            "Complex",
+            "Fine-grained Detail",
+            "Imagination",
+            "Linguistic Structures",
+            "Perspective",
+            "Properties & Positioning",
+            "Quantity",
+            "Simple Detail",
+            "Style & Format",
+            "Writing & Symbols",
+        ],
+    ),
+    "DrawBench": BenchmarkInfo(
+        name="drawbench",
+        display_name="DrawBench",
+        description="A comprehensive benchmark for evaluating text-to-image generation models.",
+        metrics=["clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+    ),
+    "GenAIBench": BenchmarkInfo(
+        name="genai_bench",
+        display_name="GenAI Bench",
+        description="A benchmark for evaluating generative AI models.",
+        metrics=["clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+    ),
+    "VBench": BenchmarkInfo(
+        name="vbench",
+        display_name="VBench",
+        description="A benchmark for evaluating video generation models.",
+        metrics=["clip_score"],
+        task_type="text_to_video",
+    ),
+    "COCO": BenchmarkInfo(
+        name="coco",
+        display_name="COCO",
+        description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
+        metrics=["fid", "clip_score", "clipiqa"],
+        task_type="text_to_image",
+    ),
+    "ImageNet": BenchmarkInfo(
+        name="imagenet",
+        display_name="ImageNet",
+        description="Large-scale image classification benchmark with 1000 classes.",
+        metrics=["accuracy"],
+        task_type="image_classification",
+    ),
+    "WikiText": BenchmarkInfo(
+        name="wikitext",
+        display_name="WikiText",
+        description="Language modeling benchmark based on Wikipedia articles.",
+        metrics=["perplexity"],
+        task_type="text_generation",
+    ),
+}
+
+
+def list_benchmarks(task_type: str | None = None) -> list[str]:
+    """
+    List available benchmark names.
+
+    Parameters
+    ----------
+    task_type : str | None
+        Filter by task type (e.g., 'text_to_image', 'text_to_video').
+        If None, returns all benchmarks.
+
+    Returns
+    -------
+    list[str]
+        List of benchmark names.
+    """
+    if task_type is None:
+        return list(benchmark_info.keys())
+    return [name for name, info in benchmark_info.items() if info.task_type == task_type]
+
+
+def get_benchmark_info(name: str) -> BenchmarkInfo:
+    """
+    Get benchmark metadata by name.
+
+    Parameters
+    ----------
+    name : str
+        The benchmark name.
+
+    Returns
+    -------
+    BenchmarkInfo
+        The benchmark metadata.
+
+    Raises
+    ------
+    KeyError
+        If benchmark name is not found.
+    """
+    if name not in benchmark_info:
+        available = ", ".join(benchmark_info.keys())
+        raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
+    return benchmark_info[name]
```
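For orientation, here is an illustrative usage sketch of the registry and helpers added above. It is not part of the diff and assumes `list_benchmarks` and `get_benchmark_info` are exported from `pruna.data` the same way the test file below imports `BenchmarkInfo` and `benchmark_info`:

```python
# Illustrative usage only; the export path of the two helper functions is an assumption.
from pruna.data import get_benchmark_info, list_benchmarks

print(list_benchmarks())                            # all registered benchmark names
print(list_benchmarks(task_type="text_to_image"))   # e.g. PartiPrompts, DrawBench, GenAIBench, COCO

info = get_benchmark_info("PartiPrompts")
print(info.display_name)    # "Parti Prompts"
print(info.metrics)         # ['arniqa', 'clip_score', 'clipiqa', 'sharpness']
print(info.subsets[:3])     # ['Abstract', 'Animals', 'Artifacts']

try:
    get_benchmark_info("UnknownBenchmark")
except KeyError as err:
    print(err)              # error message lists the available benchmark names
```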
```diff
@@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
     return ds.select([0]), ds.select([0]), ds


-def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
+def setup_parti_prompts_dataset(
+    seed: int,
+    category: str | None = None,
+    num_samples: int | None = None,
+) -> Tuple[Dataset, Dataset, Dataset]:
     """
     Setup the Parti Prompts dataset.

@@ -51,13 +55,30 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
     ----------
     seed : int
         The seed to use.
+    category : str | None
+        Filter by Category or Challenge. Available categories: Abstract, Animals, Artifacts,
+        Arts, Food & Beverage, Illustrations, Indoor Scenes, Outdoor Scenes, People,
+        Produce & Plants, Vehicles, World Knowledge. Available challenges: Basic, Complex,
+        Fine-grained Detail, Imagination, Linguistic Structures, Perspective,
+        Properties & Positioning, Quantity, Simple Detail, Style & Format, Writing & Symbols.
+    num_samples : int | None
+        Maximum number of samples to return. If None, returns all samples.

     Returns
     -------
     Tuple[Dataset, Dataset, Dataset]
-        The Parti Prompts dataset.
+        The Parti Prompts dataset (dummy train, dummy val, test).
     """
     ds = load_dataset("nateraw/parti-prompts")["train"]  # type: ignore[index]

+    if category is not None:
+        ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)
```
**Review comment (Medium Severity):** Category filter fails silently when a list is passed.
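If the filter should also accept a list of names, one possible shape for that handling is sketched below. This is only a sketch, not part of the PR, which accepts a single string:

```python
from typing import Iterable

from datasets import Dataset


def filter_parti_prompts(ds: Dataset, category: str | Iterable[str] | None) -> Dataset:
    """Sketch (not in this PR): accept one Category/Challenge name or a list of names."""
    if category is None:
        return ds
    # Normalize to a list so a passed-in list no longer slips through the equality check unmatched.
    wanted = [category] if isinstance(category, str) else list(category)
    return ds.filter(lambda x: x["Category"] in wanted or x["Challenge"] in wanted)
```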
```diff
+
     ds = ds.shuffle(seed=seed)
```
**Member:** Since we are only creating a test set, how do you feel about not shuffling the data?
```diff
+
+    if num_samples is not None:
+        ds = ds.select(range(min(num_samples, len(ds))))
+
     ds = ds.rename_column("Prompt", "text")
+    pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.")
     return ds.select([0]), ds.select([0]), ds
```
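To illustrate the new parameters end to end, here is a sketch; keyword forwarding from `PrunaDataModule.from_string` to the setup function is assumed to work as exercised by the new test further down:

```python
from pruna.data.pruna_datamodule import PrunaDataModule

# Load only the "Animals" subset of PartiPrompts with small batches.
dm = PrunaDataModule.from_string(
    "PartiPrompts", category="Animals", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)

# The prompt_with_auxiliaries_collate now wired up for PartiPrompts yields
# (prompts, auxiliaries) batches, per the assertions in the new test.
prompts, auxiliaries = next(iter(dm.test_dataloader()))
print(prompts[0])                   # an "Animals" prompt string
print(auxiliaries[0]["Category"])   # "Animals"
```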
```diff
@@ -1,10 +1,10 @@
 from typing import Any, Callable

 import pytest
-from transformers import AutoTokenizer
 from datasets import Dataset
 from torch.utils.data import TensorDataset
 import torch
+from transformers import AutoTokenizer

+from pruna.data import BenchmarkInfo, benchmark_info
 from pruna.data.datasets.image import setup_imagenet_dataset
 from pruna.data.pruna_datamodule import PrunaDataModule
```
```diff
@@ -80,3 +80,19 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar
     assert labels.dtype == torch.int64
     # iterate through the dataloaders
     iterate_dataloaders(datamodule)
+
+
+@pytest.mark.slow
+def test_parti_prompts_with_category_filter():
+    """Test PartiPrompts loading with category filter."""
+    dm = PrunaDataModule.from_string(
+        "PartiPrompts", category="Animals", dataloader_args={"batch_size": 4}
+    )
+    dm.limit_datasets(10)
+    batch = next(iter(dm.test_dataloader()))
+    prompts, auxiliaries = batch
+
+    assert len(prompts) == 4
+    assert all(isinstance(p, str) for p in prompts)
+    assert all(aux["Category"] == "Animals" for aux in auxiliaries)
```
**Review comment:** I don't think we should tightly couple benchmarking datasets with metrics. I think benchmarks should have their datasets available as PrunaDataModules, and the metrics for the benchmarks should be Pruna Metrics. This way we can give the user the flexibility to use whichever dataset with whichever metric they choose. How do you feel?
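A rough sketch of the decoupling described here; only `PrunaDataModule.from_string` and `benchmark_info` come from this PR, and `evaluate` is a hypothetical placeholder rather than an existing pruna entry point:

```python
from pruna.data import benchmark_info
from pruna.data.pruna_datamodule import PrunaDataModule

# The benchmark dataset stays an ordinary PrunaDataModule...
datamodule = PrunaDataModule.from_string("PartiPrompts", dataloader_args={"batch_size": 4})

# ...while the registry only suggests default metrics; users can pick any Pruna metrics instead.
suggested_metrics = benchmark_info["PartiPrompts"].metrics  # ['arniqa', 'clip_score', 'clipiqa', 'sharpness']
chosen_metrics = ["clip_score", "fid"]                      # user's own choice

# Hypothetical evaluation call (placeholder name, not a real pruna API):
# results = evaluate(model, datamodule=datamodule, metrics=chosen_metrics)
```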