From 6db8f0b51522f78a78feb3e4c407939b1e491d84 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 22 Jan 2026 10:58:01 +0100 Subject: [PATCH 01/10] feat: add benchmark support to PrunaDataModule and implement PartiPrompts benchmark - Introduced `from_benchmark` method in `PrunaDataModule` to create instances from benchmark classes. - Added `Benchmark`, `BenchmarkEntry`, and `BenchmarkRegistry` classes for managing benchmarks. - Implemented `PartiPrompts` benchmark for text-to-image generation with various categories and challenges. - Created utility function `benchmark_to_datasets` to convert benchmarks into datasets compatible with `PrunaDataModule`. - Added integration tests for benchmark functionality and data module interactions. --- src/pruna/data/pruna_datamodule.py | 54 +++ src/pruna/evaluation/benchmarks/__init__.py | 24 ++ src/pruna/evaluation/benchmarks/adapter.py | 70 ++++ src/pruna/evaluation/benchmarks/base.py | 86 +++++ src/pruna/evaluation/benchmarks/registry.py | 66 ++++ .../benchmarks/text_to_image/__init__.py | 67 ++++ .../benchmarks/text_to_image/parti.py | 316 ++++++++++++++++++ tests/evaluation/test_benchmarks.py | 158 +++++++++ .../evaluation/test_benchmarks_integration.py | 123 +++++++ 9 files changed, 964 insertions(+) create mode 100644 src/pruna/evaluation/benchmarks/__init__.py create mode 100644 src/pruna/evaluation/benchmarks/adapter.py create mode 100644 src/pruna/evaluation/benchmarks/base.py create mode 100644 src/pruna/evaluation/benchmarks/registry.py create mode 100644 src/pruna/evaluation/benchmarks/text_to_image/__init__.py create mode 100644 src/pruna/evaluation/benchmarks/text_to_image/parti.py create mode 100644 tests/evaluation/test_benchmarks.py create mode 100644 tests/evaluation/test_benchmarks_integration.py diff --git a/src/pruna/data/pruna_datamodule.py b/src/pruna/data/pruna_datamodule.py index 435d7eec..30b47b29 100644 --- a/src/pruna/data/pruna_datamodule.py +++ b/src/pruna/data/pruna_datamodule.py @@ -25,6 +25,9 @@ from transformers.tokenization_utils import PreTrainedTokenizer as AutoTokenizer from pruna.data import base_datasets +from pruna.evaluation.benchmarks.adapter import benchmark_to_datasets +from pruna.evaluation.benchmarks.base import Benchmark +from pruna.evaluation.benchmarks.registry import BenchmarkRegistry from pruna.data.collate import pruna_collate_fns from pruna.data.utils import TokenizerMissingError from pruna.logging.logger import pruna_logger @@ -161,6 +164,13 @@ def from_string( PrunaDataModule The PrunaDataModule. """ + # Check if it's a benchmark first + benchmark_class = BenchmarkRegistry.get(dataset_name) + if benchmark_class is not None: + return cls.from_benchmark( + benchmark_class(seed=seed), tokenizer, collate_fn_args, dataloader_args + ) + setup_fn, collate_fn_name, default_collate_fn_args = base_datasets[dataset_name] # use default collate_fn_args and override with user-provided ones @@ -179,6 +189,50 @@ def from_string( (train_ds, val_ds, test_ds), collate_fn_name, tokenizer, collate_fn_args, dataloader_args ) + @classmethod + def from_benchmark( + cls, + benchmark: Benchmark, + tokenizer: AutoTokenizer | None = None, + collate_fn_args: dict = dict(), + dataloader_args: dict = dict(), + ) -> "PrunaDataModule": + """ + Create a PrunaDataModule from a Benchmark instance. + + Parameters + ---------- + benchmark : Benchmark + The benchmark instance. + tokenizer : AutoTokenizer | None + The tokenizer to use (if needed for the task type). 
+ collate_fn_args : dict + Any additional arguments for the collate function. + dataloader_args : dict + Any additional arguments for the dataloader. + + Returns + ------- + PrunaDataModule + The PrunaDataModule. + """ + train_ds, val_ds, test_ds = benchmark_to_datasets(benchmark) + + # Determine collate function based on task type + task_to_collate = { + "text_to_image": "prompt_collate", + "text_generation": "text_generation_collate", + "audio": "audio_collate", + "image_classification": "image_classification_collate", + "question_answering": "question_answering_collate", + } + + collate_fn_name = task_to_collate.get(benchmark.task_type, "prompt_collate") + + return cls.from_datasets( + (train_ds, val_ds, test_ds), collate_fn_name, tokenizer, collate_fn_args, dataloader_args + ) + def limit_datasets(self, limit: int | list[int] | tuple[int, int, int]) -> None: """ Limit the dataset to the given number of samples. diff --git a/src/pruna/evaluation/benchmarks/__init__.py b/src/pruna/evaluation/benchmarks/__init__.py new file mode 100644 index 00000000..8a5c9df8 --- /dev/null +++ b/src/pruna/evaluation/benchmarks/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pruna.evaluation.benchmarks.base import Benchmark, BenchmarkEntry, TASK +from pruna.evaluation.benchmarks.registry import BenchmarkRegistry + +# Auto-register all benchmarks +from pruna.evaluation.benchmarks import text_to_image # noqa: F401 + +# Auto-register all benchmark subclasses +BenchmarkRegistry.auto_register_subclasses(text_to_image) + +__all__ = ["Benchmark", "BenchmarkEntry", "BenchmarkRegistry", "TASK"] diff --git a/src/pruna/evaluation/benchmarks/adapter.py b/src/pruna/evaluation/benchmarks/adapter.py new file mode 100644 index 00000000..b82bb273 --- /dev/null +++ b/src/pruna/evaluation/benchmarks/adapter.py @@ -0,0 +1,70 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Tuple + +from datasets import Dataset + +from pruna.evaluation.benchmarks.base import Benchmark +from pruna.logging.logger import pruna_logger + + +def benchmark_to_datasets(benchmark: Benchmark) -> Tuple[Dataset, Dataset, Dataset]: + """ + Convert a Benchmark instance to train/val/test datasets compatible with PrunaDataModule. + + Parameters + ---------- + benchmark : Benchmark + The benchmark instance to convert. 
+ + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Train, validation, and test datasets. For test-only benchmarks, + train and val are dummy datasets with a single item. + """ + entries = list(benchmark) + + # Convert BenchmarkEntries to dict format expected by datasets + # For prompt-based benchmarks, we need "text" field for prompt_collate + data = [] + for entry in entries: + row = entry.model_inputs.copy() + row.update(entry.additional_info) + + # Ensure "text" field exists for prompt collate functions + if "text" not in row and "prompt" in row: + row["text"] = row["prompt"] + elif "text" not in row: + # If neither exists, use the first string value + for key, value in row.items(): + if isinstance(value, str): + row["text"] = value + break + + # Add path if needed for some collate functions + if "path" not in row: + row["path"] = entry.path + data.append(row) + + dataset = Dataset.from_list(data) + + # For test-only benchmarks (like PartiPrompts), create dummy train/val + pruna_logger.info(f"{benchmark.display_name} is a test-only dataset. Do not use it for training or validation.") + dummy = dataset.select([0]) if len(dataset) > 0 else dataset + + return dummy, dummy, dataset diff --git a/src/pruna/evaluation/benchmarks/base.py b/src/pruna/evaluation/benchmarks/base.py new file mode 100644 index 00000000..62837b03 --- /dev/null +++ b/src/pruna/evaluation/benchmarks/base.py @@ -0,0 +1,86 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Iterator, List, Literal + +TASK = Literal[ + "text_to_image", + "text_generation", + "audio", + "image_classification", + "question_answering", +] + + +@dataclass +class BenchmarkEntry: + """A single entry in a benchmark dataset.""" + + model_inputs: dict[str, Any] + model_outputs: dict[str, Any] = field(default_factory=dict) + path: str = "" + additional_info: dict[str, Any] = field(default_factory=dict) + task_type: TASK = "text_to_image" + + +class Benchmark(ABC): + """Base class for all benchmark datasets.""" + + def __init__(self): + """Initialize the benchmark. 
Override to load data lazily or eagerly.""" + pass + + @abstractmethod + def __iter__(self) -> Iterator[BenchmarkEntry]: + """Iterate over benchmark entries.""" + pass + + @property + @abstractmethod + def name(self) -> str: + """Return the unique name identifier for this benchmark.""" + pass + + @property + @abstractmethod + def display_name(self) -> str: + """Return the human-readable display name for this benchmark.""" + pass + + @abstractmethod + def __len__(self) -> int: + """Return the number of items in the benchmark.""" + pass + + @property + @abstractmethod + def metrics(self) -> List[str]: + """Return the list of metric names recommended for this benchmark.""" + pass + + @property + @abstractmethod + def task_type(self) -> TASK: + """Return the task type for this benchmark.""" + pass + + @property + @abstractmethod + def description(self) -> str: + """Return a description of this benchmark.""" + pass diff --git a/src/pruna/evaluation/benchmarks/registry.py b/src/pruna/evaluation/benchmarks/registry.py new file mode 100644 index 00000000..2e6399ab --- /dev/null +++ b/src/pruna/evaluation/benchmarks/registry.py @@ -0,0 +1,66 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from typing import Type + +from pruna.evaluation.benchmarks.base import Benchmark + + +class BenchmarkRegistry: + """Registry for automatically discovering and registering benchmark classes.""" + + _registry: dict[str, Type[Benchmark]] = {} + + @classmethod + def register(cls, benchmark_class: Type[Benchmark]) -> Type[Benchmark]: + """Register a benchmark class by its name property.""" + # Create instance with default args to get the name + # This assumes benchmarks have default or no required arguments + try: + instance = benchmark_class() + name = instance.name + except Exception as e: + raise ValueError( + f"Failed to create instance of {benchmark_class.__name__} for registration: {e}. " + "Ensure the benchmark class can be instantiated with default arguments." 
+ ) from e + + if name in cls._registry: + raise ValueError(f"Benchmark with name '{name}' is already registered.") + cls._registry[name] = benchmark_class + return benchmark_class + + @classmethod + def get(cls, name: str) -> Type[Benchmark] | None: + """Get a benchmark class by name.""" + return cls._registry.get(name) + + @classmethod + def list_all(cls) -> dict[str, Type[Benchmark]]: + """List all registered benchmarks.""" + return cls._registry.copy() + + @classmethod + def auto_register_subclasses(cls, module) -> None: + """Automatically register all Benchmark subclasses in a module.""" + for name, obj in inspect.getmembers(module, inspect.isclass): + if ( + issubclass(obj, Benchmark) + and obj is not Benchmark + and (obj.__module__ == module.__name__ or obj.__module__.startswith(module.__name__ + ".")) + ): + cls.register(obj) diff --git a/src/pruna/evaluation/benchmarks/text_to_image/__init__.py b/src/pruna/evaluation/benchmarks/text_to_image/__init__.py new file mode 100644 index 00000000..d5f38c69 --- /dev/null +++ b/src/pruna/evaluation/benchmarks/text_to_image/__init__.py @@ -0,0 +1,67 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pruna.evaluation.benchmarks.text_to_image.parti import ( + PartiPrompts, + PartiPromptsAbstract, + PartiPromptsAnimals, + PartiPromptsArtifacts, + PartiPromptsArts, + PartiPromptsBasic, + PartiPromptsComplex, + PartiPromptsFineGrainedDetail, + PartiPromptsFoodBeverage, + PartiPromptsImagination, + PartiPromptsIllustrations, + PartiPromptsIndoorScenes, + PartiPromptsLinguisticStructures, + PartiPromptsOutdoorScenes, + PartiPromptsPeople, + PartiPromptsPerspective, + PartiPromptsProducePlants, + PartiPromptsPropertiesPositioning, + PartiPromptsQuantity, + PartiPromptsSimpleDetail, + PartiPromptsStyleFormat, + PartiPromptsVehicles, + PartiPromptsWorldKnowledge, + PartiPromptsWritingSymbols, +) + +__all__ = [ + "PartiPrompts", + "PartiPromptsAbstract", + "PartiPromptsAnimals", + "PartiPromptsArtifacts", + "PartiPromptsArts", + "PartiPromptsBasic", + "PartiPromptsComplex", + "PartiPromptsFineGrainedDetail", + "PartiPromptsFoodBeverage", + "PartiPromptsImagination", + "PartiPromptsIllustrations", + "PartiPromptsIndoorScenes", + "PartiPromptsLinguisticStructures", + "PartiPromptsOutdoorScenes", + "PartiPromptsPeople", + "PartiPromptsPerspective", + "PartiPromptsProducePlants", + "PartiPromptsPropertiesPositioning", + "PartiPromptsQuantity", + "PartiPromptsSimpleDetail", + "PartiPromptsStyleFormat", + "PartiPromptsVehicles", + "PartiPromptsWorldKnowledge", + "PartiPromptsWritingSymbols", +] diff --git a/src/pruna/evaluation/benchmarks/text_to_image/parti.py b/src/pruna/evaluation/benchmarks/text_to_image/parti.py new file mode 100644 index 00000000..089b64df --- /dev/null +++ b/src/pruna/evaluation/benchmarks/text_to_image/parti.py @@ -0,0 +1,316 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Iterator, List, cast + +from datasets import Dataset, load_dataset + +from pruna.evaluation.benchmarks.base import TASK, Benchmark, BenchmarkEntry + + +class PartiPrompts(Benchmark): + """Parti Prompts benchmark for text-to-image generation.""" + + def __init__( + self, + seed: int = 42, + num_samples: int | None = None, + subset: str | None = None, + ): + """ + Initialize the Parti Prompts benchmark. + + Parameters + ---------- + seed : int + Random seed for shuffling. Default is 42. + num_samples : int | None + Number of samples to select. If None, uses all samples. Default is None. + subset : str | None + Filter by a subset of the dataset. For PartiPrompts, this can be either: + + **Categories:** + - "Abstract" + - "Animals" + - "Artifacts" + - "Arts" + - "Food & Beverage" + - "Illustrations" + - "Indoor Scenes" + - "Outdoor Scenes" + - "People" + - "Produce & Plants" + - "Vehicles" + - "World Knowledge" + + **Challenges:** + - "Basic" + - "Complex" + - "Fine-grained Detail" + - "Imagination" + - "Linguistic Structures" + - "Perspective" + - "Properties & Positioning" + - "Quantity" + - "Simple Detail" + - "Style & Format" + - "Writing & Symbols" + + If None, includes all samples. Default is None. 
+ """ + super().__init__() + self._seed = seed + self._num_samples = num_samples + + # Determine if subset refers to a dataset category or challenge + # Check against known challenges + self.subset = subset + + def _load_prompts(self) -> List[dict]: + """Load prompts from the dataset.""" + dataset_dict = load_dataset("nateraw/parti-prompts") # type: ignore + dataset = cast(Dataset, dataset_dict["train"]) # type: ignore + if self.subset is not None: + dataset = dataset.filter(lambda x: x["Category"] == self.subset or x["Challenge"] == self.subset) + shuffled_dataset = dataset.shuffle(seed=self._seed) + if self._num_samples is not None: + selected_dataset = shuffled_dataset.select(range(min(self._num_samples, len(shuffled_dataset)))) + else: + selected_dataset = shuffled_dataset + return list(selected_dataset) + + def __iter__(self) -> Iterator[BenchmarkEntry]: + """Iterate over benchmark entries.""" + for i, row in enumerate(self._load_prompts()): + yield BenchmarkEntry( + model_inputs={"prompt": row["Prompt"]}, + model_outputs={}, + path=f"{i}.png", + additional_info={ + "category": row["Category"], + "challenge": row["Challenge"], + "note": row.get("Note", ""), + }, + task_type=self.task_type, + ) + + @property + def name(self) -> str: + """Return the unique name identifier.""" + if self.subset is None: + return "parti_prompts" + normalized = ( + self.subset.lower().replace(" & ", "_").replace(" ", "_").replace("&", "_").replace("__", "_").rstrip("_") + ) + return f"parti_prompts_{normalized}" + + @property + def display_name(self) -> str: + """Return the human-readable display name.""" + if self.subset is None: + return "Parti Prompts" + return f"Parti Prompts ({self.subset})" + + def __len__(self) -> int: + """Return the number of entries in the benchmark.""" + return len(self._load_prompts()) + + @property + def metrics(self) -> List[str]: + """Return the list of recommended metrics.""" + return ["arniqa", "clip", "clip_iqa", "sharpness"] + + @property + def task_type(self) -> TASK: + """Return the task type.""" + return "text_to_image" + + @property + def description(self) -> str: + """Return a description of the benchmark.""" + return ( + "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects " + "ranging from basic to complex, enabling comprehensive assessment of model capabilities " + "across different domains and difficulty levels." 
+ ) + + +# Category-based subclasses +class PartiPromptsAbstract(PartiPrompts): + """Parti Prompts filtered by Abstract category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Abstract" if subset is None else subset) + + +class PartiPromptsAnimals(PartiPrompts): + """Parti Prompts filtered by Animals category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Animals" if subset is None else subset) + + +class PartiPromptsArtifacts(PartiPrompts): + """Parti Prompts filtered by Artifacts category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Artifacts" if subset is None else subset) + + +class PartiPromptsArts(PartiPrompts): + """Parti Prompts filtered by Arts category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Arts" if subset is None else subset) + + +class PartiPromptsFoodBeverage(PartiPrompts): + """Parti Prompts filtered by Food & Beverage category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Food & Beverage" if subset is None else subset) + + +class PartiPromptsIllustrations(PartiPrompts): + """Parti Prompts filtered by Illustrations category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Illustrations" if subset is None else subset) + + +class PartiPromptsIndoorScenes(PartiPrompts): + """Parti Prompts filtered by Indoor Scenes category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Indoor Scenes" if subset is None else subset) + + +class PartiPromptsOutdoorScenes(PartiPrompts): + """Parti Prompts filtered by Outdoor Scenes category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Outdoor Scenes" if subset is None else subset) + + +class PartiPromptsPeople(PartiPrompts): + """Parti Prompts filtered by People category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="People" if subset is None else subset) + + +class PartiPromptsProducePlants(PartiPrompts): + """Parti Prompts filtered by Produce & Plants category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Produce & Plants" if subset is None else subset) + + +class PartiPromptsVehicles(PartiPrompts): + """Parti Prompts filtered by Vehicles category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Vehicles" if subset is None else subset) + + +class PartiPromptsWorldKnowledge(PartiPrompts): + """Parti Prompts filtered by World Knowledge category.""" + + def 
__init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="World Knowledge" if subset is None else subset) + + +# Challenge-based subclasses +class PartiPromptsBasic(PartiPrompts): + """Parti Prompts filtered by Basic challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + # subset can be a category to further filter when challenge is already set + super().__init__(seed=seed, num_samples=num_samples, subset="Basic" if subset is None else subset) + + +class PartiPromptsComplex(PartiPrompts): + """Parti Prompts filtered by Complex challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Complex" if subset is None else subset) + + +class PartiPromptsFineGrainedDetail(PartiPrompts): + """Parti Prompts filtered by Fine-grained Detail challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Fine-grained Detail" if subset is None else subset) + + +class PartiPromptsImagination(PartiPrompts): + """Parti Prompts filtered by Imagination challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Imagination" if subset is None else subset) + + +class PartiPromptsLinguisticStructures(PartiPrompts): + """Parti Prompts filtered by Linguistic Structures challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__( + seed=seed, num_samples=num_samples, subset="Linguistic Structures" if subset is None else subset + ) + + +class PartiPromptsPerspective(PartiPrompts): + """Parti Prompts filtered by Perspective challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Perspective" if subset is None else subset) + + +class PartiPromptsPropertiesPositioning(PartiPrompts): + """Parti Prompts filtered by Properties & Positioning challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__( + seed=seed, num_samples=num_samples, subset="Properties & Positioning" if subset is None else subset + ) + + +class PartiPromptsQuantity(PartiPrompts): + """Parti Prompts filtered by Quantity challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Quantity" if subset is None else subset) + + +class PartiPromptsSimpleDetail(PartiPrompts): + """Parti Prompts filtered by Simple Detail challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Simple Detail" if subset is None else subset) + + +class PartiPromptsStyleFormat(PartiPrompts): + """Parti Prompts filtered by Style & Format challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Style & Format" if subset is None else subset) + + +class 
PartiPromptsWritingSymbols(PartiPrompts): + """Parti Prompts filtered by Writing & Symbols challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Writing & Symbols" if subset is None else subset) diff --git a/tests/evaluation/test_benchmarks.py b/tests/evaluation/test_benchmarks.py new file mode 100644 index 00000000..d9345874 --- /dev/null +++ b/tests/evaluation/test_benchmarks.py @@ -0,0 +1,158 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the benchmarks module.""" + +import pytest + +from pruna.evaluation.benchmarks.base import Benchmark, BenchmarkEntry, TASK +from pruna.evaluation.benchmarks.registry import BenchmarkRegistry +from pruna.evaluation.benchmarks.adapter import benchmark_to_datasets +from pruna.evaluation.benchmarks.text_to_image.parti import PartiPrompts + + +def test_benchmark_entry_creation(): + """Test creating a BenchmarkEntry with all fields.""" + entry = BenchmarkEntry( + model_inputs={"prompt": "test prompt"}, + model_outputs={"image": "test_image.png"}, + path="test.png", + additional_info={"category": "test"}, + task_type="text_to_image", + ) + + assert entry.model_inputs == {"prompt": "test prompt"} + assert entry.model_outputs == {"image": "test_image.png"} + assert entry.path == "test.png" + assert entry.additional_info == {"category": "test"} + assert entry.task_type == "text_to_image" + + +def test_benchmark_entry_defaults(): + """Test BenchmarkEntry with default values.""" + entry = BenchmarkEntry(model_inputs={"prompt": "test"}) + + assert entry.model_inputs == {"prompt": "test"} + assert entry.model_outputs == {} + assert entry.path == "" + assert entry.additional_info == {} + assert entry.task_type == "text_to_image" + + +def test_task_type_literal(): + """Test that TASK type only accepts valid task types.""" + # Valid task types + valid_tasks: list[TASK] = [ + "text_to_image", + "text_generation", + "audio", + "image_classification", + "question_answering", + ] + + for task in valid_tasks: + entry = BenchmarkEntry(model_inputs={}, task_type=task) + assert entry.task_type == task + + +def test_benchmark_registry_get(): + """Test getting a benchmark from the registry.""" + benchmark_class = BenchmarkRegistry.get("parti_prompts") + assert benchmark_class is not None + assert issubclass(benchmark_class, Benchmark) + + +def test_benchmark_registry_list_all(): + """Test listing all registered benchmarks.""" + all_benchmarks = BenchmarkRegistry.list_all() + assert isinstance(all_benchmarks, dict) + assert len(all_benchmarks) > 0 + assert "parti_prompts" in all_benchmarks + + +def test_benchmark_registry_get_nonexistent(): + """Test getting a non-existent benchmark returns None.""" + benchmark_class = BenchmarkRegistry.get("nonexistent_benchmark") + assert benchmark_class is None + + +def test_parti_prompts_creation(): + """Test creating a PartiPrompts benchmark 
instance.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + + assert benchmark.name == "parti_prompts" + assert benchmark.display_name == "Parti Prompts" + assert benchmark.task_type == "text_to_image" + assert len(benchmark.metrics) > 0 + assert isinstance(benchmark.description, str) + + +def test_parti_prompts_iteration(): + """Test iterating over PartiPrompts entries.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + entries = list(benchmark) + + assert len(entries) == 5 + for entry in entries: + assert isinstance(entry, BenchmarkEntry) + assert "prompt" in entry.model_inputs + assert entry.task_type == "text_to_image" + assert entry.model_outputs == {} + + +def test_parti_prompts_length(): + """Test PartiPrompts __len__ method.""" + benchmark = PartiPrompts(seed=42, num_samples=10) + assert len(benchmark) == 10 + + +def test_parti_prompts_subset(): + """Test PartiPrompts with a subset filter.""" + benchmark = PartiPrompts(seed=42, num_samples=5, subset="Animals") + + assert "animals" in benchmark.name.lower() + assert "Animals" in benchmark.display_name + + entries = list(benchmark) + for entry in entries: + assert entry.additional_info.get("category") == "Animals" + + +def test_benchmark_to_datasets(): + """Test converting a benchmark to datasets.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + train_ds, val_ds, test_ds = benchmark_to_datasets(benchmark) + + assert len(test_ds) == 5 + assert len(train_ds) == 1 # Dummy dataset + assert len(val_ds) == 1 # Dummy dataset + + # Check that test dataset has the expected fields + sample = test_ds[0] + assert "prompt" in sample or "text" in sample + + +def test_benchmark_entry_task_type_validation(): + """Test that BenchmarkEntry validates task_type.""" + # This should work + entry = BenchmarkEntry( + model_inputs={}, + task_type="text_to_image", + ) + assert entry.task_type == "text_to_image" + + # Test other valid task types + for task in ["text_generation", "audio", "image_classification", "question_answering"]: + entry = BenchmarkEntry(model_inputs={}, task_type=task) + assert entry.task_type == task diff --git a/tests/evaluation/test_benchmarks_integration.py b/tests/evaluation/test_benchmarks_integration.py new file mode 100644 index 00000000..e7198f07 --- /dev/null +++ b/tests/evaluation/test_benchmarks_integration.py @@ -0,0 +1,123 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration tests for benchmarks with datamodule and metrics.""" + +import pytest + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.evaluation.benchmarks.registry import BenchmarkRegistry +from pruna.evaluation.benchmarks.text_to_image.parti import PartiPrompts +from pruna.evaluation.metrics.registry import MetricRegistry + + +@pytest.mark.cpu +def test_datamodule_from_benchmark(): + """Test creating a PrunaDataModule from a benchmark.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + datamodule = PrunaDataModule.from_benchmark(benchmark) + + assert datamodule is not None + assert datamodule.test_dataset is not None + assert len(datamodule.test_dataset) == 5 + + +@pytest.mark.cpu +def test_datamodule_from_benchmark_string(): + """Test creating a PrunaDataModule from a benchmark name string.""" + datamodule = PrunaDataModule.from_string("parti_prompts", seed=42) + + assert datamodule is not None + # Limit to small number for testing + datamodule.limit_datasets(5) + + # Test that we can iterate through the dataloader + test_loader = datamodule.test_dataloader(batch_size=2) + batch = next(iter(test_loader)) + assert batch is not None + + +@pytest.mark.cpu +def test_benchmark_with_metrics(): + """Test that benchmarks provide recommended metrics.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + recommended_metrics = benchmark.metrics + + assert isinstance(recommended_metrics, list) + assert len(recommended_metrics) > 0 + + # Check that metrics can be retrieved from registry + for metric_name in recommended_metrics: + # Some metrics might be registered, some might not + # Just verify the names are strings + assert isinstance(metric_name, str) + + +@pytest.mark.cpu +def test_benchmark_registry_integration(): + """Test that benchmarks are properly registered and can be used.""" + # Get benchmark from registry + benchmark_class = BenchmarkRegistry.get("parti_prompts") + assert benchmark_class is not None + + # Create instance + benchmark = benchmark_class(seed=42, num_samples=3) + + # Verify it works with datamodule + datamodule = PrunaDataModule.from_benchmark(benchmark) + assert datamodule is not None + + # Verify we can get entries + entries = list(benchmark) + assert len(entries) == 3 + + +@pytest.mark.cpu +def test_benchmark_task_type_mapping(): + """Test that benchmark task types map correctly to collate functions.""" + benchmark = PartiPrompts(seed=42, num_samples=3) + + # Create datamodule and verify it uses the correct collate function + datamodule = PrunaDataModule.from_benchmark(benchmark) + + # The collate function should be set based on task_type + assert datamodule.collate_fn is not None + + # Verify we can use the dataloader + test_loader = datamodule.test_dataloader(batch_size=1) + batch = next(iter(test_loader)) + assert batch is not None + + +@pytest.mark.cpu +def test_benchmark_entry_model_outputs(): + """Test that BenchmarkEntry can store model outputs.""" + from pruna.evaluation.benchmarks.base import BenchmarkEntry + + entry = BenchmarkEntry( + model_inputs={"prompt": "test"}, + model_outputs={"image": "generated_image.png", "score": 0.95}, + ) + + assert entry.model_outputs == {"image": "generated_image.png", "score": 0.95} + + # Verify entries from benchmark have empty model_outputs by default + benchmark = PartiPrompts(seed=42, num_samples=2) + entries = list(benchmark) + + for entry in entries: + assert entry.model_outputs == {} + # But model_outputs field exists and can be populated + entry.model_outputs["test"] = "value" + assert 
entry.model_outputs["test"] == "value" From 7c53c95a654528b4851ba8f93dbc7bb0aa7339d8 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 15:50:16 +0100 Subject: [PATCH 02/10] refactor: simplify benchmark system, extend PartiPrompts with subset filtering - Remove heavy benchmark abstraction (Benchmark class, registry, adapter, 24 subclasses) - Extend setup_parti_prompts_dataset with category and num_samples params - Add BenchmarkInfo dataclass for metadata (metrics, description, subsets) - Switch PartiPrompts to prompt_with_auxiliaries_collate to preserve Category/Challenge - Merge tests into test_datamodule.py Reduces 964 lines to 128 lines (87% reduction) Co-authored-by: Cursor --- src/pruna/data/__init__.py | 76 ++++- src/pruna/data/datasets/prompt.py | 25 +- src/pruna/data/pruna_datamodule.py | 54 --- src/pruna/evaluation/benchmarks/__init__.py | 24 -- src/pruna/evaluation/benchmarks/adapter.py | 70 ---- src/pruna/evaluation/benchmarks/base.py | 86 ----- src/pruna/evaluation/benchmarks/registry.py | 66 ---- .../benchmarks/text_to_image/__init__.py | 67 ---- .../benchmarks/text_to_image/parti.py | 316 ------------------ tests/data/test_datamodule.py | 22 +- tests/evaluation/test_benchmarks.py | 158 --------- .../evaluation/test_benchmarks_integration.py | 123 ------- 12 files changed, 117 insertions(+), 970 deletions(-) delete mode 100644 src/pruna/evaluation/benchmarks/__init__.py delete mode 100644 src/pruna/evaluation/benchmarks/adapter.py delete mode 100644 src/pruna/evaluation/benchmarks/base.py delete mode 100644 src/pruna/evaluation/benchmarks/registry.py delete mode 100644 src/pruna/evaluation/benchmarks/text_to_image/__init__.py delete mode 100644 src/pruna/evaluation/benchmarks/text_to_image/parti.py delete mode 100644 tests/evaluation/test_benchmarks.py delete mode 100644 tests/evaluation/test_benchmarks_integration.py diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 3a811868..820d1262 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass, field from functools import partial from typing import Any, Callable, Tuple @@ -97,8 +98,81 @@ {"img_size": 224}, ), "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), - "PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}), + "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}), "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), } + + +@dataclass +class BenchmarkInfo: + """Metadata for a benchmark dataset.""" + + name: str + display_name: str + description: str + metrics: list[str] + task_type: str + subsets: list[str] = field(default_factory=list) + + +benchmark_info: dict[str, BenchmarkInfo] = { + "PartiPrompts": BenchmarkInfo( + name="parti_prompts", + display_name="Parti Prompts", + description=( + "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects " + "ranging from basic to complex, enabling comprehensive assessment of model capabilities " + "across different domains and difficulty levels." 
+ ), + metrics=["arniqa", "clip", "clip_iqa", "sharpness"], + task_type="text_to_image", + subsets=[ + "Abstract", + "Animals", + "Artifacts", + "Arts", + "Food & Beverage", + "Illustrations", + "Indoor Scenes", + "Outdoor Scenes", + "People", + "Produce & Plants", + "Vehicles", + "World Knowledge", + "Basic", + "Complex", + "Fine-grained Detail", + "Imagination", + "Linguistic Structures", + "Perspective", + "Properties & Positioning", + "Quantity", + "Simple Detail", + "Style & Format", + "Writing & Symbols", + ], + ), + "DrawBench": BenchmarkInfo( + name="drawbench", + display_name="DrawBench", + description="A comprehensive benchmark for evaluating text-to-image generation models.", + metrics=["clip", "clip_iqa", "sharpness"], + task_type="text_to_image", + ), + "GenAIBench": BenchmarkInfo( + name="genai_bench", + display_name="GenAI Bench", + description="A benchmark for evaluating generative AI models.", + metrics=["clip", "clip_iqa", "sharpness"], + task_type="text_to_image", + ), + "VBench": BenchmarkInfo( + name="vbench", + display_name="VBench", + description="A benchmark for evaluating video generation models.", + metrics=["clip", "fvd"], + task_type="text_to_video", + ), +} diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 1f6fab71..4f275675 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: return ds.select([0]), ds.select([0]), ds -def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: +def setup_parti_prompts_dataset( + seed: int, + category: str | None = None, + num_samples: int | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: """ Setup the Parti Prompts dataset. @@ -51,13 +55,30 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: ---------- seed : int The seed to use. + category : str | None + Filter by Category or Challenge. Available categories: Abstract, Animals, Artifacts, + Arts, Food & Beverage, Illustrations, Indoor Scenes, Outdoor Scenes, People, + Produce & Plants, Vehicles, World Knowledge. Available challenges: Basic, Complex, + Fine-grained Detail, Imagination, Linguistic Structures, Perspective, + Properties & Positioning, Quantity, Simple Detail, Style & Format, Writing & Symbols. + num_samples : int | None + Maximum number of samples to return. If None, returns all samples. Returns ------- Tuple[Dataset, Dataset, Dataset] - The Parti Prompts dataset. + The Parti Prompts dataset (dummy train, dummy val, test). """ ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index] + + if category is not None: + ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category) + + ds = ds.shuffle(seed=seed) + + if num_samples is not None: + ds = ds.select(range(min(num_samples, len(ds)))) + ds = ds.rename_column("Prompt", "text") pruna_logger.info("PartiPrompts is a test-only dataset. 
Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds diff --git a/src/pruna/data/pruna_datamodule.py b/src/pruna/data/pruna_datamodule.py index 30b47b29..435d7eec 100644 --- a/src/pruna/data/pruna_datamodule.py +++ b/src/pruna/data/pruna_datamodule.py @@ -25,9 +25,6 @@ from transformers.tokenization_utils import PreTrainedTokenizer as AutoTokenizer from pruna.data import base_datasets -from pruna.evaluation.benchmarks.adapter import benchmark_to_datasets -from pruna.evaluation.benchmarks.base import Benchmark -from pruna.evaluation.benchmarks.registry import BenchmarkRegistry from pruna.data.collate import pruna_collate_fns from pruna.data.utils import TokenizerMissingError from pruna.logging.logger import pruna_logger @@ -164,13 +161,6 @@ def from_string( PrunaDataModule The PrunaDataModule. """ - # Check if it's a benchmark first - benchmark_class = BenchmarkRegistry.get(dataset_name) - if benchmark_class is not None: - return cls.from_benchmark( - benchmark_class(seed=seed), tokenizer, collate_fn_args, dataloader_args - ) - setup_fn, collate_fn_name, default_collate_fn_args = base_datasets[dataset_name] # use default collate_fn_args and override with user-provided ones @@ -189,50 +179,6 @@ def from_string( (train_ds, val_ds, test_ds), collate_fn_name, tokenizer, collate_fn_args, dataloader_args ) - @classmethod - def from_benchmark( - cls, - benchmark: Benchmark, - tokenizer: AutoTokenizer | None = None, - collate_fn_args: dict = dict(), - dataloader_args: dict = dict(), - ) -> "PrunaDataModule": - """ - Create a PrunaDataModule from a Benchmark instance. - - Parameters - ---------- - benchmark : Benchmark - The benchmark instance. - tokenizer : AutoTokenizer | None - The tokenizer to use (if needed for the task type). - collate_fn_args : dict - Any additional arguments for the collate function. - dataloader_args : dict - Any additional arguments for the dataloader. - - Returns - ------- - PrunaDataModule - The PrunaDataModule. - """ - train_ds, val_ds, test_ds = benchmark_to_datasets(benchmark) - - # Determine collate function based on task type - task_to_collate = { - "text_to_image": "prompt_collate", - "text_generation": "text_generation_collate", - "audio": "audio_collate", - "image_classification": "image_classification_collate", - "question_answering": "question_answering_collate", - } - - collate_fn_name = task_to_collate.get(benchmark.task_type, "prompt_collate") - - return cls.from_datasets( - (train_ds, val_ds, test_ds), collate_fn_name, tokenizer, collate_fn_args, dataloader_args - ) - def limit_datasets(self, limit: int | list[int] | tuple[int, int, int]) -> None: """ Limit the dataset to the given number of samples. diff --git a/src/pruna/evaluation/benchmarks/__init__.py b/src/pruna/evaluation/benchmarks/__init__.py deleted file mode 100644 index 8a5c9df8..00000000 --- a/src/pruna/evaluation/benchmarks/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from pruna.evaluation.benchmarks.base import Benchmark, BenchmarkEntry, TASK -from pruna.evaluation.benchmarks.registry import BenchmarkRegistry - -# Auto-register all benchmarks -from pruna.evaluation.benchmarks import text_to_image # noqa: F401 - -# Auto-register all benchmark subclasses -BenchmarkRegistry.auto_register_subclasses(text_to_image) - -__all__ = ["Benchmark", "BenchmarkEntry", "BenchmarkRegistry", "TASK"] diff --git a/src/pruna/evaluation/benchmarks/adapter.py b/src/pruna/evaluation/benchmarks/adapter.py deleted file mode 100644 index b82bb273..00000000 --- a/src/pruna/evaluation/benchmarks/adapter.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Tuple - -from datasets import Dataset - -from pruna.evaluation.benchmarks.base import Benchmark -from pruna.logging.logger import pruna_logger - - -def benchmark_to_datasets(benchmark: Benchmark) -> Tuple[Dataset, Dataset, Dataset]: - """ - Convert a Benchmark instance to train/val/test datasets compatible with PrunaDataModule. - - Parameters - ---------- - benchmark : Benchmark - The benchmark instance to convert. - - Returns - ------- - Tuple[Dataset, Dataset, Dataset] - Train, validation, and test datasets. For test-only benchmarks, - train and val are dummy datasets with a single item. - """ - entries = list(benchmark) - - # Convert BenchmarkEntries to dict format expected by datasets - # For prompt-based benchmarks, we need "text" field for prompt_collate - data = [] - for entry in entries: - row = entry.model_inputs.copy() - row.update(entry.additional_info) - - # Ensure "text" field exists for prompt collate functions - if "text" not in row and "prompt" in row: - row["text"] = row["prompt"] - elif "text" not in row: - # If neither exists, use the first string value - for key, value in row.items(): - if isinstance(value, str): - row["text"] = value - break - - # Add path if needed for some collate functions - if "path" not in row: - row["path"] = entry.path - data.append(row) - - dataset = Dataset.from_list(data) - - # For test-only benchmarks (like PartiPrompts), create dummy train/val - pruna_logger.info(f"{benchmark.display_name} is a test-only dataset. Do not use it for training or validation.") - dummy = dataset.select([0]) if len(dataset) > 0 else dataset - - return dummy, dummy, dataset diff --git a/src/pruna/evaluation/benchmarks/base.py b/src/pruna/evaluation/benchmarks/base.py deleted file mode 100644 index 62837b03..00000000 --- a/src/pruna/evaluation/benchmarks/base.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Iterator, List, Literal - -TASK = Literal[ - "text_to_image", - "text_generation", - "audio", - "image_classification", - "question_answering", -] - - -@dataclass -class BenchmarkEntry: - """A single entry in a benchmark dataset.""" - - model_inputs: dict[str, Any] - model_outputs: dict[str, Any] = field(default_factory=dict) - path: str = "" - additional_info: dict[str, Any] = field(default_factory=dict) - task_type: TASK = "text_to_image" - - -class Benchmark(ABC): - """Base class for all benchmark datasets.""" - - def __init__(self): - """Initialize the benchmark. Override to load data lazily or eagerly.""" - pass - - @abstractmethod - def __iter__(self) -> Iterator[BenchmarkEntry]: - """Iterate over benchmark entries.""" - pass - - @property - @abstractmethod - def name(self) -> str: - """Return the unique name identifier for this benchmark.""" - pass - - @property - @abstractmethod - def display_name(self) -> str: - """Return the human-readable display name for this benchmark.""" - pass - - @abstractmethod - def __len__(self) -> int: - """Return the number of items in the benchmark.""" - pass - - @property - @abstractmethod - def metrics(self) -> List[str]: - """Return the list of metric names recommended for this benchmark.""" - pass - - @property - @abstractmethod - def task_type(self) -> TASK: - """Return the task type for this benchmark.""" - pass - - @property - @abstractmethod - def description(self) -> str: - """Return a description of this benchmark.""" - pass diff --git a/src/pruna/evaluation/benchmarks/registry.py b/src/pruna/evaluation/benchmarks/registry.py deleted file mode 100644 index 2e6399ab..00000000 --- a/src/pruna/evaluation/benchmarks/registry.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import inspect -from typing import Type - -from pruna.evaluation.benchmarks.base import Benchmark - - -class BenchmarkRegistry: - """Registry for automatically discovering and registering benchmark classes.""" - - _registry: dict[str, Type[Benchmark]] = {} - - @classmethod - def register(cls, benchmark_class: Type[Benchmark]) -> Type[Benchmark]: - """Register a benchmark class by its name property.""" - # Create instance with default args to get the name - # This assumes benchmarks have default or no required arguments - try: - instance = benchmark_class() - name = instance.name - except Exception as e: - raise ValueError( - f"Failed to create instance of {benchmark_class.__name__} for registration: {e}. " - "Ensure the benchmark class can be instantiated with default arguments." - ) from e - - if name in cls._registry: - raise ValueError(f"Benchmark with name '{name}' is already registered.") - cls._registry[name] = benchmark_class - return benchmark_class - - @classmethod - def get(cls, name: str) -> Type[Benchmark] | None: - """Get a benchmark class by name.""" - return cls._registry.get(name) - - @classmethod - def list_all(cls) -> dict[str, Type[Benchmark]]: - """List all registered benchmarks.""" - return cls._registry.copy() - - @classmethod - def auto_register_subclasses(cls, module) -> None: - """Automatically register all Benchmark subclasses in a module.""" - for name, obj in inspect.getmembers(module, inspect.isclass): - if ( - issubclass(obj, Benchmark) - and obj is not Benchmark - and (obj.__module__ == module.__name__ or obj.__module__.startswith(module.__name__ + ".")) - ): - cls.register(obj) diff --git a/src/pruna/evaluation/benchmarks/text_to_image/__init__.py b/src/pruna/evaluation/benchmarks/text_to_image/__init__.py deleted file mode 100644 index d5f38c69..00000000 --- a/src/pruna/evaluation/benchmarks/text_to_image/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from pruna.evaluation.benchmarks.text_to_image.parti import ( - PartiPrompts, - PartiPromptsAbstract, - PartiPromptsAnimals, - PartiPromptsArtifacts, - PartiPromptsArts, - PartiPromptsBasic, - PartiPromptsComplex, - PartiPromptsFineGrainedDetail, - PartiPromptsFoodBeverage, - PartiPromptsImagination, - PartiPromptsIllustrations, - PartiPromptsIndoorScenes, - PartiPromptsLinguisticStructures, - PartiPromptsOutdoorScenes, - PartiPromptsPeople, - PartiPromptsPerspective, - PartiPromptsProducePlants, - PartiPromptsPropertiesPositioning, - PartiPromptsQuantity, - PartiPromptsSimpleDetail, - PartiPromptsStyleFormat, - PartiPromptsVehicles, - PartiPromptsWorldKnowledge, - PartiPromptsWritingSymbols, -) - -__all__ = [ - "PartiPrompts", - "PartiPromptsAbstract", - "PartiPromptsAnimals", - "PartiPromptsArtifacts", - "PartiPromptsArts", - "PartiPromptsBasic", - "PartiPromptsComplex", - "PartiPromptsFineGrainedDetail", - "PartiPromptsFoodBeverage", - "PartiPromptsImagination", - "PartiPromptsIllustrations", - "PartiPromptsIndoorScenes", - "PartiPromptsLinguisticStructures", - "PartiPromptsOutdoorScenes", - "PartiPromptsPeople", - "PartiPromptsPerspective", - "PartiPromptsProducePlants", - "PartiPromptsPropertiesPositioning", - "PartiPromptsQuantity", - "PartiPromptsSimpleDetail", - "PartiPromptsStyleFormat", - "PartiPromptsVehicles", - "PartiPromptsWorldKnowledge", - "PartiPromptsWritingSymbols", -] diff --git a/src/pruna/evaluation/benchmarks/text_to_image/parti.py b/src/pruna/evaluation/benchmarks/text_to_image/parti.py deleted file mode 100644 index 089b64df..00000000 --- a/src/pruna/evaluation/benchmarks/text_to_image/parti.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Iterator, List, cast - -from datasets import Dataset, load_dataset - -from pruna.evaluation.benchmarks.base import TASK, Benchmark, BenchmarkEntry - - -class PartiPrompts(Benchmark): - """Parti Prompts benchmark for text-to-image generation.""" - - def __init__( - self, - seed: int = 42, - num_samples: int | None = None, - subset: str | None = None, - ): - """ - Initialize the Parti Prompts benchmark. - - Parameters - ---------- - seed : int - Random seed for shuffling. Default is 42. - num_samples : int | None - Number of samples to select. If None, uses all samples. Default is None. - subset : str | None - Filter by a subset of the dataset. For PartiPrompts, this can be either: - - **Categories:** - - "Abstract" - - "Animals" - - "Artifacts" - - "Arts" - - "Food & Beverage" - - "Illustrations" - - "Indoor Scenes" - - "Outdoor Scenes" - - "People" - - "Produce & Plants" - - "Vehicles" - - "World Knowledge" - - **Challenges:** - - "Basic" - - "Complex" - - "Fine-grained Detail" - - "Imagination" - - "Linguistic Structures" - - "Perspective" - - "Properties & Positioning" - - "Quantity" - - "Simple Detail" - - "Style & Format" - - "Writing & Symbols" - - If None, includes all samples. 
Default is None. - """ - super().__init__() - self._seed = seed - self._num_samples = num_samples - - # Determine if subset refers to a dataset category or challenge - # Check against known challenges - self.subset = subset - - def _load_prompts(self) -> List[dict]: - """Load prompts from the dataset.""" - dataset_dict = load_dataset("nateraw/parti-prompts") # type: ignore - dataset = cast(Dataset, dataset_dict["train"]) # type: ignore - if self.subset is not None: - dataset = dataset.filter(lambda x: x["Category"] == self.subset or x["Challenge"] == self.subset) - shuffled_dataset = dataset.shuffle(seed=self._seed) - if self._num_samples is not None: - selected_dataset = shuffled_dataset.select(range(min(self._num_samples, len(shuffled_dataset)))) - else: - selected_dataset = shuffled_dataset - return list(selected_dataset) - - def __iter__(self) -> Iterator[BenchmarkEntry]: - """Iterate over benchmark entries.""" - for i, row in enumerate(self._load_prompts()): - yield BenchmarkEntry( - model_inputs={"prompt": row["Prompt"]}, - model_outputs={}, - path=f"{i}.png", - additional_info={ - "category": row["Category"], - "challenge": row["Challenge"], - "note": row.get("Note", ""), - }, - task_type=self.task_type, - ) - - @property - def name(self) -> str: - """Return the unique name identifier.""" - if self.subset is None: - return "parti_prompts" - normalized = ( - self.subset.lower().replace(" & ", "_").replace(" ", "_").replace("&", "_").replace("__", "_").rstrip("_") - ) - return f"parti_prompts_{normalized}" - - @property - def display_name(self) -> str: - """Return the human-readable display name.""" - if self.subset is None: - return "Parti Prompts" - return f"Parti Prompts ({self.subset})" - - def __len__(self) -> int: - """Return the number of entries in the benchmark.""" - return len(self._load_prompts()) - - @property - def metrics(self) -> List[str]: - """Return the list of recommended metrics.""" - return ["arniqa", "clip", "clip_iqa", "sharpness"] - - @property - def task_type(self) -> TASK: - """Return the task type.""" - return "text_to_image" - - @property - def description(self) -> str: - """Return a description of the benchmark.""" - return ( - "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects " - "ranging from basic to complex, enabling comprehensive assessment of model capabilities " - "across different domains and difficulty levels." 
- ) - - -# Category-based subclasses -class PartiPromptsAbstract(PartiPrompts): - """Parti Prompts filtered by Abstract category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Abstract" if subset is None else subset) - - -class PartiPromptsAnimals(PartiPrompts): - """Parti Prompts filtered by Animals category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Animals" if subset is None else subset) - - -class PartiPromptsArtifacts(PartiPrompts): - """Parti Prompts filtered by Artifacts category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Artifacts" if subset is None else subset) - - -class PartiPromptsArts(PartiPrompts): - """Parti Prompts filtered by Arts category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Arts" if subset is None else subset) - - -class PartiPromptsFoodBeverage(PartiPrompts): - """Parti Prompts filtered by Food & Beverage category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Food & Beverage" if subset is None else subset) - - -class PartiPromptsIllustrations(PartiPrompts): - """Parti Prompts filtered by Illustrations category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Illustrations" if subset is None else subset) - - -class PartiPromptsIndoorScenes(PartiPrompts): - """Parti Prompts filtered by Indoor Scenes category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Indoor Scenes" if subset is None else subset) - - -class PartiPromptsOutdoorScenes(PartiPrompts): - """Parti Prompts filtered by Outdoor Scenes category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Outdoor Scenes" if subset is None else subset) - - -class PartiPromptsPeople(PartiPrompts): - """Parti Prompts filtered by People category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="People" if subset is None else subset) - - -class PartiPromptsProducePlants(PartiPrompts): - """Parti Prompts filtered by Produce & Plants category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Produce & Plants" if subset is None else subset) - - -class PartiPromptsVehicles(PartiPrompts): - """Parti Prompts filtered by Vehicles category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Vehicles" if subset is None else subset) - - -class PartiPromptsWorldKnowledge(PartiPrompts): - """Parti Prompts filtered by World Knowledge category.""" - - def 
__init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="World Knowledge" if subset is None else subset) - - -# Challenge-based subclasses -class PartiPromptsBasic(PartiPrompts): - """Parti Prompts filtered by Basic challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - # subset can be a category to further filter when challenge is already set - super().__init__(seed=seed, num_samples=num_samples, subset="Basic" if subset is None else subset) - - -class PartiPromptsComplex(PartiPrompts): - """Parti Prompts filtered by Complex challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Complex" if subset is None else subset) - - -class PartiPromptsFineGrainedDetail(PartiPrompts): - """Parti Prompts filtered by Fine-grained Detail challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Fine-grained Detail" if subset is None else subset) - - -class PartiPromptsImagination(PartiPrompts): - """Parti Prompts filtered by Imagination challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Imagination" if subset is None else subset) - - -class PartiPromptsLinguisticStructures(PartiPrompts): - """Parti Prompts filtered by Linguistic Structures challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__( - seed=seed, num_samples=num_samples, subset="Linguistic Structures" if subset is None else subset - ) - - -class PartiPromptsPerspective(PartiPrompts): - """Parti Prompts filtered by Perspective challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Perspective" if subset is None else subset) - - -class PartiPromptsPropertiesPositioning(PartiPrompts): - """Parti Prompts filtered by Properties & Positioning challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__( - seed=seed, num_samples=num_samples, subset="Properties & Positioning" if subset is None else subset - ) - - -class PartiPromptsQuantity(PartiPrompts): - """Parti Prompts filtered by Quantity challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Quantity" if subset is None else subset) - - -class PartiPromptsSimpleDetail(PartiPrompts): - """Parti Prompts filtered by Simple Detail challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Simple Detail" if subset is None else subset) - - -class PartiPromptsStyleFormat(PartiPrompts): - """Parti Prompts filtered by Style & Format challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Style & Format" if subset is None else subset) - - -class 
PartiPromptsWritingSymbols(PartiPrompts): - """Parti Prompts filtered by Writing & Symbols challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Writing & Symbols" if subset is None else subset) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 04d226f6..61550698 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,10 +1,10 @@ from typing import Any, Callable import pytest -from transformers import AutoTokenizer -from datasets import Dataset -from torch.utils.data import TensorDataset import torch +from transformers import AutoTokenizer + +from pruna.data import BenchmarkInfo, benchmark_info from pruna.data.datasets.image import setup_imagenet_dataset from pruna.data.pruna_datamodule import PrunaDataModule @@ -80,3 +80,19 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar assert labels.dtype == torch.int64 # iterate through the dataloaders iterate_dataloaders(datamodule) + + + +@pytest.mark.slow +def test_parti_prompts_with_category_filter(): + """Test PartiPrompts loading with category filter.""" + dm = PrunaDataModule.from_string( + "PartiPrompts", category="Animals", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(aux["Category"] == "Animals" for aux in auxiliaries) diff --git a/tests/evaluation/test_benchmarks.py b/tests/evaluation/test_benchmarks.py deleted file mode 100644 index d9345874..00000000 --- a/tests/evaluation/test_benchmarks.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for the benchmarks module.""" - -import pytest - -from pruna.evaluation.benchmarks.base import Benchmark, BenchmarkEntry, TASK -from pruna.evaluation.benchmarks.registry import BenchmarkRegistry -from pruna.evaluation.benchmarks.adapter import benchmark_to_datasets -from pruna.evaluation.benchmarks.text_to_image.parti import PartiPrompts - - -def test_benchmark_entry_creation(): - """Test creating a BenchmarkEntry with all fields.""" - entry = BenchmarkEntry( - model_inputs={"prompt": "test prompt"}, - model_outputs={"image": "test_image.png"}, - path="test.png", - additional_info={"category": "test"}, - task_type="text_to_image", - ) - - assert entry.model_inputs == {"prompt": "test prompt"} - assert entry.model_outputs == {"image": "test_image.png"} - assert entry.path == "test.png" - assert entry.additional_info == {"category": "test"} - assert entry.task_type == "text_to_image" - - -def test_benchmark_entry_defaults(): - """Test BenchmarkEntry with default values.""" - entry = BenchmarkEntry(model_inputs={"prompt": "test"}) - - assert entry.model_inputs == {"prompt": "test"} - assert entry.model_outputs == {} - assert entry.path == "" - assert entry.additional_info == {} - assert entry.task_type == "text_to_image" - - -def test_task_type_literal(): - """Test that TASK type only accepts valid task types.""" - # Valid task types - valid_tasks: list[TASK] = [ - "text_to_image", - "text_generation", - "audio", - "image_classification", - "question_answering", - ] - - for task in valid_tasks: - entry = BenchmarkEntry(model_inputs={}, task_type=task) - assert entry.task_type == task - - -def test_benchmark_registry_get(): - """Test getting a benchmark from the registry.""" - benchmark_class = BenchmarkRegistry.get("parti_prompts") - assert benchmark_class is not None - assert issubclass(benchmark_class, Benchmark) - - -def test_benchmark_registry_list_all(): - """Test listing all registered benchmarks.""" - all_benchmarks = BenchmarkRegistry.list_all() - assert isinstance(all_benchmarks, dict) - assert len(all_benchmarks) > 0 - assert "parti_prompts" in all_benchmarks - - -def test_benchmark_registry_get_nonexistent(): - """Test getting a non-existent benchmark returns None.""" - benchmark_class = BenchmarkRegistry.get("nonexistent_benchmark") - assert benchmark_class is None - - -def test_parti_prompts_creation(): - """Test creating a PartiPrompts benchmark instance.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - - assert benchmark.name == "parti_prompts" - assert benchmark.display_name == "Parti Prompts" - assert benchmark.task_type == "text_to_image" - assert len(benchmark.metrics) > 0 - assert isinstance(benchmark.description, str) - - -def test_parti_prompts_iteration(): - """Test iterating over PartiPrompts entries.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - entries = list(benchmark) - - assert len(entries) == 5 - for entry in entries: - assert isinstance(entry, BenchmarkEntry) - assert "prompt" in entry.model_inputs - assert entry.task_type == "text_to_image" - assert entry.model_outputs == {} - - -def test_parti_prompts_length(): - """Test PartiPrompts __len__ method.""" - benchmark = PartiPrompts(seed=42, num_samples=10) - assert len(benchmark) == 10 - - -def test_parti_prompts_subset(): - """Test PartiPrompts with a subset filter.""" - benchmark = PartiPrompts(seed=42, num_samples=5, subset="Animals") - - assert "animals" in benchmark.name.lower() - assert "Animals" in benchmark.display_name - - entries = list(benchmark) - for entry in entries: - 
assert entry.additional_info.get("category") == "Animals" - - -def test_benchmark_to_datasets(): - """Test converting a benchmark to datasets.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - train_ds, val_ds, test_ds = benchmark_to_datasets(benchmark) - - assert len(test_ds) == 5 - assert len(train_ds) == 1 # Dummy dataset - assert len(val_ds) == 1 # Dummy dataset - - # Check that test dataset has the expected fields - sample = test_ds[0] - assert "prompt" in sample or "text" in sample - - -def test_benchmark_entry_task_type_validation(): - """Test that BenchmarkEntry validates task_type.""" - # This should work - entry = BenchmarkEntry( - model_inputs={}, - task_type="text_to_image", - ) - assert entry.task_type == "text_to_image" - - # Test other valid task types - for task in ["text_generation", "audio", "image_classification", "question_answering"]: - entry = BenchmarkEntry(model_inputs={}, task_type=task) - assert entry.task_type == task diff --git a/tests/evaluation/test_benchmarks_integration.py b/tests/evaluation/test_benchmarks_integration.py deleted file mode 100644 index e7198f07..00000000 --- a/tests/evaluation/test_benchmarks_integration.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Integration tests for benchmarks with datamodule and metrics.""" - -import pytest - -from pruna.data.pruna_datamodule import PrunaDataModule -from pruna.evaluation.benchmarks.registry import BenchmarkRegistry -from pruna.evaluation.benchmarks.text_to_image.parti import PartiPrompts -from pruna.evaluation.metrics.registry import MetricRegistry - - -@pytest.mark.cpu -def test_datamodule_from_benchmark(): - """Test creating a PrunaDataModule from a benchmark.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - datamodule = PrunaDataModule.from_benchmark(benchmark) - - assert datamodule is not None - assert datamodule.test_dataset is not None - assert len(datamodule.test_dataset) == 5 - - -@pytest.mark.cpu -def test_datamodule_from_benchmark_string(): - """Test creating a PrunaDataModule from a benchmark name string.""" - datamodule = PrunaDataModule.from_string("parti_prompts", seed=42) - - assert datamodule is not None - # Limit to small number for testing - datamodule.limit_datasets(5) - - # Test that we can iterate through the dataloader - test_loader = datamodule.test_dataloader(batch_size=2) - batch = next(iter(test_loader)) - assert batch is not None - - -@pytest.mark.cpu -def test_benchmark_with_metrics(): - """Test that benchmarks provide recommended metrics.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - recommended_metrics = benchmark.metrics - - assert isinstance(recommended_metrics, list) - assert len(recommended_metrics) > 0 - - # Check that metrics can be retrieved from registry - for metric_name in recommended_metrics: - # Some metrics might be registered, some might not - # Just verify the names are strings - assert isinstance(metric_name, str) - - -@pytest.mark.cpu -def test_benchmark_registry_integration(): - """Test that benchmarks are properly registered and can be used.""" - # Get benchmark from registry - benchmark_class = BenchmarkRegistry.get("parti_prompts") - assert benchmark_class is not None - - # Create instance - benchmark = benchmark_class(seed=42, num_samples=3) - - # Verify it works with datamodule - datamodule = PrunaDataModule.from_benchmark(benchmark) - assert datamodule is not None - - # Verify we can get entries - entries = list(benchmark) - assert len(entries) == 3 - - -@pytest.mark.cpu -def test_benchmark_task_type_mapping(): - """Test that benchmark task types map correctly to collate functions.""" - benchmark = PartiPrompts(seed=42, num_samples=3) - - # Create datamodule and verify it uses the correct collate function - datamodule = PrunaDataModule.from_benchmark(benchmark) - - # The collate function should be set based on task_type - assert datamodule.collate_fn is not None - - # Verify we can use the dataloader - test_loader = datamodule.test_dataloader(batch_size=1) - batch = next(iter(test_loader)) - assert batch is not None - - -@pytest.mark.cpu -def test_benchmark_entry_model_outputs(): - """Test that BenchmarkEntry can store model outputs.""" - from pruna.evaluation.benchmarks.base import BenchmarkEntry - - entry = BenchmarkEntry( - model_inputs={"prompt": "test"}, - model_outputs={"image": "generated_image.png", "score": 0.95}, - ) - - assert entry.model_outputs == {"image": "generated_image.png", "score": 0.95} - - # Verify entries from benchmark have empty model_outputs by default - benchmark = PartiPrompts(seed=42, num_samples=2) - entries = list(benchmark) - - for entry in entries: - assert entry.model_outputs == {} - # But model_outputs field exists and can be populated - entry.model_outputs["test"] = "value" - assert 
entry.model_outputs["test"] == "value" From 975adb3fdb3845eb363367902162a7b7269edd5a Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:18:22 +0100 Subject: [PATCH 03/10] fix: add Numpydoc parameter docs for BenchmarkInfo Document all dataclass fields per Numpydoc PR01 with summary on new line per GL01. Co-authored-by: Cursor --- src/pruna/data/__init__.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 820d1262..86e36cd4 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -107,7 +107,24 @@ @dataclass class BenchmarkInfo: - """Metadata for a benchmark dataset.""" + """ + Metadata for a benchmark dataset. + + Parameters + ---------- + name : str + Internal identifier for the benchmark. + display_name : str + Human-readable name for display purposes. + description : str + Description of what the benchmark evaluates. + metrics : list[str] + List of metric names used for evaluation. + task_type : str + Type of task the benchmark evaluates (e.g., 'text_to_image'). + subsets : list[str] + Optional list of benchmark subset names. + """ name: str display_name: str From 6b0f4f7182b5fd39de462bbfb9fd396ab84efb51 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:20:59 +0100 Subject: [PATCH 04/10] feat: add benchmark discovery functions and expand benchmark registry - Add list_benchmarks() to filter benchmarks by task type - Add get_benchmark_info() to retrieve benchmark metadata - Add COCO, ImageNet, WikiText to benchmark_info registry Co-authored-by: Cursor --- src/pruna/data/__init__.py | 66 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 86e36cd4..315db752 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -192,4 +192,70 @@ class BenchmarkInfo: metrics=["clip", "fvd"], task_type="text_to_video", ), + "COCO": BenchmarkInfo( + name="coco", + display_name="COCO", + description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.", + metrics=["fid", "clip", "clip_iqa"], + task_type="text_to_image", + ), + "ImageNet": BenchmarkInfo( + name="imagenet", + display_name="ImageNet", + description="Large-scale image classification benchmark with 1000 classes.", + metrics=["accuracy", "top5_accuracy"], + task_type="image_classification", + ), + "WikiText": BenchmarkInfo( + name="wikitext", + display_name="WikiText", + description="Language modeling benchmark based on Wikipedia articles.", + metrics=["perplexity"], + task_type="text_generation", + ), } + + +def list_benchmarks(task_type: str | None = None) -> list[str]: + """ + List available benchmark names. + + Parameters + ---------- + task_type : str | None + Filter by task type (e.g., 'text_to_image', 'text_to_video'). + If None, returns all benchmarks. + + Returns + ------- + list[str] + List of benchmark names. + """ + if task_type is None: + return list(benchmark_info.keys()) + return [name for name, info in benchmark_info.items() if info.task_type == task_type] + + +def get_benchmark_info(name: str) -> BenchmarkInfo: + """ + Get benchmark metadata by name. + + Parameters + ---------- + name : str + The benchmark name. + + Returns + ------- + BenchmarkInfo + The benchmark metadata. + + Raises + ------ + KeyError + If benchmark name is not found. 
+ """ + if name not in benchmark_info: + available = ", ".join(benchmark_info.keys()) + raise KeyError(f"Benchmark '{name}' not found. Available: {available}") + return benchmark_info[name] From 56f2167391d00b96ba08d0968e6c0be588904a83 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:25:10 +0100 Subject: [PATCH 05/10] fix: use correct metric names from MetricRegistry Update benchmark metrics to match registered names: - clip -> clip_score - clip_iqa -> clipiqa - Remove unimplemented top5_accuracy Co-authored-by: Cursor --- src/pruna/data/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 315db752..ce404911 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -143,7 +143,7 @@ class BenchmarkInfo: "ranging from basic to complex, enabling comprehensive assessment of model capabilities " "across different domains and difficulty levels." ), - metrics=["arniqa", "clip", "clip_iqa", "sharpness"], + metrics=["arniqa", "clip_score", "clipiqa", "sharpness"], task_type="text_to_image", subsets=[ "Abstract", @@ -175,35 +175,35 @@ class BenchmarkInfo: name="drawbench", display_name="DrawBench", description="A comprehensive benchmark for evaluating text-to-image generation models.", - metrics=["clip", "clip_iqa", "sharpness"], + metrics=["clip_score", "clipiqa", "sharpness"], task_type="text_to_image", ), "GenAIBench": BenchmarkInfo( name="genai_bench", display_name="GenAI Bench", description="A benchmark for evaluating generative AI models.", - metrics=["clip", "clip_iqa", "sharpness"], + metrics=["clip_score", "clipiqa", "sharpness"], task_type="text_to_image", ), "VBench": BenchmarkInfo( name="vbench", display_name="VBench", description="A benchmark for evaluating video generation models.", - metrics=["clip", "fvd"], + metrics=["clip_score"], task_type="text_to_video", ), "COCO": BenchmarkInfo( name="coco", display_name="COCO", description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.", - metrics=["fid", "clip", "clip_iqa"], + metrics=["fid", "clip_score", "clipiqa"], task_type="text_to_image", ), "ImageNet": BenchmarkInfo( name="imagenet", display_name="ImageNet", description="Large-scale image classification benchmark with 1000 classes.", - metrics=["accuracy", "top5_accuracy"], + metrics=["accuracy"], task_type="image_classification", ), "WikiText": BenchmarkInfo( From c4d946726273e54a3094d7f6b85305b38a1d0ec0 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 17:00:59 +0100 Subject: [PATCH 06/10] feat: add ImgEdit benchmark with edit type subsets Closes #510 - Add setup_imgedit_dataset in datasets/prompt.py - Support subset filter (replace, add, remove, adjust, extract, style, background, compose) - Fetch instructions and judge prompts from GitHub (PKU-YuanGroup/ImgEdit) - Register ImgEdit in base_datasets - Add BenchmarkInfo entry with accuracy metric, task_type image_edit - Add test for loading with subset filter Co-authored-by: Cursor --- src/pruna/data/__init__.py | 10 +++++ src/pruna/data/datasets/prompt.py | 67 +++++++++++++++++++++++++++++++ tests/data/test_datamodule.py | 17 ++++++++ 3 files changed, 94 insertions(+) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index ce404911..af9d72b4 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -29,6 +29,7 @@ from pruna.data.datasets.prompt import ( setup_drawbench_dataset, 
setup_genai_bench_dataset, + setup_imgedit_dataset, setup_parti_prompts_dataset, ) from pruna.data.datasets.question_answering import setup_polyglot_dataset @@ -100,6 +101,7 @@ "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}), "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), + "ImgEdit": (setup_imgedit_dataset, "prompt_with_auxiliaries_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), } @@ -213,6 +215,14 @@ class BenchmarkInfo: metrics=["perplexity"], task_type="text_generation", ), + "ImgEdit": BenchmarkInfo( + name="imgedit", + display_name="ImgEdit", + description="Comprehensive image editing benchmark with 8 edit types: replace, add, remove, adjust, extract, style, background, compose.", + metrics=["accuracy"], + task_type="image_edit", + subsets=["replace", "add", "remove", "adjust", "extract", "style", "background", "compose"], + ), } diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 4f275675..295cb27f 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -104,3 +104,70 @@ def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: ds = ds.rename_column("Prompt", "text") pruna_logger.info("GenAI-Bench is a test-only dataset. Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds + + +IMGEDIT_SUBSETS = ["replace", "add", "remove", "adjust", "extract", "style", "background", "compose"] + + +def setup_imgedit_dataset( + seed: int, + subset: str | None = None, + num_samples: int | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the ImgEdit benchmark dataset for image editing evaluation. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + subset : str | None + Filter by edit type. Available: replace, add, remove, adjust, extract, style, + background, compose. If None, returns all subsets. + num_samples : int | None + Maximum number of samples to return. If None, returns all samples. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The ImgEdit dataset (dummy train, dummy val, test). + """ + import json + + import requests + + if subset is not None and subset not in IMGEDIT_SUBSETS: + raise ValueError(f"Invalid subset: {subset}. Must be one of {IMGEDIT_SUBSETS}") + + instructions_url = "https://raw.githubusercontent.com/PKU-YuanGroup/ImgEdit/b3eb8e74d7cd1fd0ce5341eaf9254744a8ab4c0b/Benchmark/Basic/basic_edit.json" + judge_prompts_url = "https://raw.githubusercontent.com/PKU-YuanGroup/ImgEdit/c14480ac5e7b622e08cd8c46f96624a48eb9ab46/Benchmark/Basic/prompts.json" + + instructions = json.loads(requests.get(instructions_url).text) + judge_prompts = json.loads(requests.get(judge_prompts_url).text) + + records = [] + for _, instruction in instructions.items(): + edit_type = instruction.get("edit_type", "") + + if subset is not None and edit_type != subset: + continue + + records.append({ + "text": instruction.get("prompt", ""), + "subset": edit_type, + "image_id": instruction.get("id", ""), + "judge_prompt": judge_prompts.get(edit_type, ""), + }) + + ds = Dataset.from_list(records) + ds = ds.shuffle(seed=seed) + + if num_samples is not None: + ds = ds.select(range(min(num_samples, len(ds)))) + + pruna_logger.info("ImgEdit is a test-only dataset. 
Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 61550698..a02218fb 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -45,6 +45,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("GenAIBench", dict(), marks=pytest.mark.slow), pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("VBench", dict(), marks=pytest.mark.slow), + pytest.param("ImgEdit", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: @@ -96,3 +97,19 @@ def test_parti_prompts_with_category_filter(): assert len(prompts) == 4 assert all(isinstance(p, str) for p in prompts) assert all(aux["Category"] == "Animals" for aux in auxiliaries) + + +@pytest.mark.slow +def test_imgedit_with_subset_filter(): + """Test ImgEdit loading with subset filter.""" + dm = PrunaDataModule.from_string( + "ImgEdit", subset="replace", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(aux["subset"] == "replace" for aux in auxiliaries) + assert all("judge_prompt" in aux for aux in auxiliaries) From 21097aa0f3ac438da32208f5a1fa4c564a615567 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 17:04:31 +0100 Subject: [PATCH 07/10] feat: add GEditBench benchmark with task type subsets Closes #511 - Add setup_gedit_dataset in datasets/prompt.py - Support subset filter (11 task types including background_change, color_alter, etc.) - Fetch data from HuggingFace (stepfun-ai/GEdit-Bench), filter English only - Register GEditBench in base_datasets - Add BenchmarkInfo entry with accuracy metric, task_type image_edit - Add test for loading with subset filter Co-authored-by: Cursor --- [conflicted 7].coverage | Bin 0 -> 69632 bytes src/pruna/data/__init__.py | 13 +++++++ src/pruna/data/datasets/prompt.py | 62 ++++++++++++++++++++++++++++++ tests/data/test_datamodule.py | 16 ++++++++ 4 files changed, 91 insertions(+) create mode 100644 [conflicted 7].coverage diff --git a/ [conflicted 7].coverage b/ [conflicted 7].coverage new file mode 100644 index 0000000000000000000000000000000000000000..27b077c2b3c32d19ae38645285c8e2c0b1f98096 GIT binary patch literal 69632 zcmeHQ33OandA@Jv%|7$)Sc@fDw)HH@mSsyCOY$Z!(#W>rx=HXBZ?Wak)97h5_H3RR z*|9_HH%TA?%5qu~NC_cjJ*-XPKvGH*6=+FFNy4G+A*GF7I0sBfh$$g9brSdg@4S)T zc;P^M2TsC$NBaM_-246i{qBBe?!2AbZ;0o0bv&I-YI(Jn1c@Y(u&NS5Jn%mY{!51o zPUO-BXvuPZty2%_I=L@EcN4$!-Gm+tY^9z4`vZ%7-}NUw@AN(3(mYXkAP&KRU_dY+ z82IC6ps?NVs%vPFZoe(BjVAPbR*UM{q6}=@F*JD7kb2YLh8u>|;yrb-SB0aqQyo;Z z>0@d}&#L3`gs#R@WAUh#kEdd4enKyWp3LcE*wMi~=;|S%tY#9fjXgWKF-j8OrRCGemb*S2O3ZrT8^Ok%uwb^tsluxUP zc#0)6m5b-&>6EJ9qDLq5`q)(y*ny;#4j|8(8o(OC1dSA9nzk(L%tXFQOf$F{c^sTe zrAxHVL`2gNnfaWWOpXHOhbEI5%p8bOG@MUgE9SN0=F?_kf9t%tN?y*Sw=6qLlw*z- zV@;;wM<(^@%lG;#sJ&&;}QK02Kp{X%JW&O;6@o`Nd0CJhAVY z@wW&tb6w&dLbbEtP`$UMbGbNltJrPj)|Bq5Z4+7!TcKQ=I+{)=bS+iv8z9D(9N3KK zROZ8#z0$x8M(#Lb#E!NxpGutzHuE<@r7P7hI7ue8)R@^{VK-bvg|;6nC)z1CA`nq+ zgG!l$DtZ}a3DkUwEmjG?D+8P$YEm03R(H%43*xZ}{R+5ZZ?s9v>-l(6S1|~4h83)1 zT275-b$AGUX?vUfHx&cStw=XYoF%()C|78bTy;yAN`*l$Wj0;KFFfT27EV3mte_w1ACUF$YJBER;f-V;ygM0KAmwMFM}~5DW+g1OtKr z!GK^uFd!HZ3x62PjizOEiGq{>Fm5_*>Y>W!H|;k95u 
z~8eW!~(sMxj5UR`J*I z@~V0!J(hO~+Y)kNe?>FeVr(5RrYt2lm)AxUdSqP7VN0@hL@sRPoE)g!lTkgAN@H1C zb3iVvOC_9B+aABiN>V|oG%geu_scsi>A zbXNsh)A59sn}}%ne5w~LbXLLwG}L-FszoQTIqitZg+b0WVNBOECH$ooNI^7`&uXb$ zLd#IP-MmLGbn~!d&go#CCV~CMb$jJPFAv;AnvGpU zCE$&`hDvPNt=Kc*DU6gR6WBjg@mfB4^Q`M(*Kw89m}@8TbP8yrhvmXP9vNP42nuX2 z4;CUY)H$YSp(_Jh3a`OVxh?1~lNp@82Y7Yg<@oWW7Q+(euME64ITpw1yNg$qX5L5B z33G^9&8tc?T+7hI$!jqquP5}R4ui=M%6JqTq846jdNbe_*PFM=g%%#Ao~?6(yRE#Y zucdYqHEZd;yphM_37A5G&%XU~p^pa%C7&$>dw3OQvmY#VBelEbLWqaYCT>nUid9?X zRoYD0V*gZ{#n0x|aG2-}Hv&`O=wYlX^Eq>pNNf0cL=)$s%!HO2OX%5nIg-#yDoOQ#c z%HziG|F@HO6Z$#29ln|Wex)byY~Vzo)&ICZ<*)I5*|*F4Q?KD|@O;s;-TiC#1Mcfw zXI+MiIzQ%I?f8ykRDM~$NA9=(#Qt{sV%w9p6ShX_VQD9M5g>o}<6yf?3WFRLI4n!8 z*43=B#ejcf1@OyOnnv5e%9aXPv8H#43*&3Ckt#`HXX3;>f`?BjQ-U?9Z7fY~sa zOZBON_Sw;-;#>$$qj2Pl` zDr?HpiyvwLQ~N4MpmlJPt_NGu%GtVN#4uh52DewvV5y@&FdOXfXur%VRX#)6T?^>l z{8CZ!dS4CTuHuJlX>xa0b8`8z<&sDh;PPn78SYgUD)m`lU@Hesm;n?y9X>U>{0oSB zu%c#4i()YfW_DI4HDwExhZHcxqZKKq?Wzlv2LoV>$2hj!*7QQ+~fxAP7c_N?o!(}>;hE&nOIs5-s1%L zUS3XH*cfpD@DK-JbF4aXJfmkb8hm1{!@434Ru9Nvfk$6uS}1;*xb%I@VLK<@%g~D# z;QT%t!1L&`On3|PJ0yT^=fJ!L6X*Z)d=$CJ#(ZgVf#{}yNLuDIjM&WJiyq9o`Ti0eb(j@!A zMubCAy&@@B-AI#J59ahrnky#IG!CP!!|TA}{z_W3jFm>#f-xQ=UkjV%8%QE+zygoO z1B->S4J7;fzyOama*KiK4J4zh0i5#(AD3?+*}DpC@R*xg5>VVg5?u+%JeqY2@)a9M z_O9UXesKdyd^sTVSZ21ke~k?!>0U6!W0CY4rb=tSxzYxb(H=0!xnNzsfn;wt|C}jq zAlbJJkU4kd<_3~T2%tH)7iAkr_I2_1-P}Mj(h1O Tuple[Dataset, Dataset, Dataset]: + """ + Setup the GEditBench dataset for image editing evaluation. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + subset : str | None + Filter by task type. Available: background_change, color_alter, material_alter, + motion_change, ps_human, style_change, subject_add, subject_remove, subject_replace, + text_change, tone_transfer. If None, returns all subsets. + num_samples : int | None + Maximum number of samples to return. If None, returns all samples. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The GEditBench dataset (dummy train, dummy val, test). + """ + if subset is not None and subset not in GEDIT_SUBSETS: + raise ValueError(f"Invalid subset: {subset}. Must be one of {GEDIT_SUBSETS}") + + task_type_map = {"subject_add": "subject-add", "subject_remove": "subject-remove", "subject_replace": "subject-replace"} + + ds = load_dataset("stepfun-ai/GEdit-Bench")["train"] # type: ignore[index] + ds = ds.filter(lambda x: x["instruction_language"] == "en") + + if subset is not None: + hf_task_type = task_type_map.get(subset, subset) + ds = ds.filter(lambda x, tt=hf_task_type: x["task_type"] == tt) + + records = [] + for row in ds: + task_type = row.get("task_type", "") + subset_name = {v: k for k, v in task_type_map}.get(task_type, task_type) + records.append({ + "text": row.get("instruction", ""), + "subset": subset_name, + }) + + ds = Dataset.from_list(records) + ds = ds.shuffle(seed=seed) + + if num_samples is not None: + ds = ds.select(range(min(num_samples, len(ds)))) + + pruna_logger.info("GEditBench is a test-only dataset. 
Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index a02218fb..8b9052fb 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -46,6 +46,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("VBench", dict(), marks=pytest.mark.slow), pytest.param("ImgEdit", dict(), marks=pytest.mark.slow), + pytest.param("GEditBench", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: @@ -113,3 +114,18 @@ def test_imgedit_with_subset_filter(): assert all(isinstance(p, str) for p in prompts) assert all(aux["subset"] == "replace" for aux in auxiliaries) assert all("judge_prompt" in aux for aux in auxiliaries) + + +@pytest.mark.slow +def test_geditbench_with_subset_filter(): + """Test GEditBench loading with subset filter.""" + dm = PrunaDataModule.from_string( + "GEditBench", subset="background_change", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(aux["subset"] == "background_change" for aux in auxiliaries) From 73532f02dc1104ba55dddd1504430735967b020a Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 17:21:46 +0100 Subject: [PATCH 08/10] fix: rename subset to category for GEditBench + add empty guard + fix linting - Rename subset parameter to category in setup_gedit_dataset - Add empty dataset guard before ds.select([0]) - Fix line too long (E501) and trailing newline (W391) issues - Update tests to use category parameter Co-authored-by: Cursor --- src/pruna/data/datasets/prompt.py | 65 ++++++++++++++++++++----------- tests/data/test_datamodule.py | 29 +++++--------- 2 files changed, 52 insertions(+), 42 deletions(-) diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 0c7aafc4..ee3a382f 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -154,12 +154,14 @@ def setup_imgedit_dataset( if subset is not None and edit_type != subset: continue - records.append({ - "text": instruction.get("prompt", ""), - "subset": edit_type, - "image_id": instruction.get("id", ""), - "judge_prompt": judge_prompts.get(edit_type, ""), - }) + records.append( + { + "text": instruction.get("prompt", ""), + "subset": edit_type, + "image_id": instruction.get("id", ""), + "judge_prompt": judge_prompts.get(edit_type, ""), + } + ) ds = Dataset.from_list(records) ds = ds.shuffle(seed=seed) @@ -171,15 +173,24 @@ def setup_imgedit_dataset( return ds.select([0]), ds.select([0]), ds -GEDIT_SUBSETS = [ - "background_change", "color_alter", "material_alter", "motion_change", "ps_human", - "style_change", "subject_add", "subject_remove", "subject_replace", "text_change", "tone_transfer" +GEDIT_CATEGORIES = [ + "background_change", + "color_alter", + "material_alter", + "motion_change", + "ps_human", + "style_change", + "subject_add", + "subject_remove", + "subject_replace", + "text_change", + "tone_transfer", ] def setup_gedit_dataset( seed: int, - subset: str | None = None, + category: str | None = None, num_samples: int | None = None, ) -> Tuple[Dataset, Dataset, Dataset]: """ @@ -191,10 +202,10 @@ def setup_gedit_dataset( ---------- seed : int The 
seed to use. - subset : str | None + category : str | None Filter by task type. Available: background_change, color_alter, material_alter, motion_change, ps_human, style_change, subject_add, subject_remove, subject_replace, - text_change, tone_transfer. If None, returns all subsets. + text_change, tone_transfer. If None, returns all categories. num_samples : int | None Maximum number of samples to return. If None, returns all samples. @@ -203,26 +214,32 @@ def setup_gedit_dataset( Tuple[Dataset, Dataset, Dataset] The GEditBench dataset (dummy train, dummy val, test). """ - if subset is not None and subset not in GEDIT_SUBSETS: - raise ValueError(f"Invalid subset: {subset}. Must be one of {GEDIT_SUBSETS}") + if category is not None and category not in GEDIT_CATEGORIES: + raise ValueError(f"Invalid category: {category}. Must be one of {GEDIT_CATEGORIES}") - task_type_map = {"subject_add": "subject-add", "subject_remove": "subject-remove", "subject_replace": "subject-replace"} + task_type_map = { + "subject_add": "subject-add", + "subject_remove": "subject-remove", + "subject_replace": "subject-replace", + } ds = load_dataset("stepfun-ai/GEdit-Bench")["train"] # type: ignore[index] ds = ds.filter(lambda x: x["instruction_language"] == "en") - if subset is not None: - hf_task_type = task_type_map.get(subset, subset) + if category is not None: + hf_task_type = task_type_map.get(category, category) ds = ds.filter(lambda x, tt=hf_task_type: x["task_type"] == tt) records = [] for row in ds: task_type = row.get("task_type", "") - subset_name = {v: k for k, v in task_type_map}.get(task_type, task_type) - records.append({ - "text": row.get("instruction", ""), - "subset": subset_name, - }) + category_name = {v: k for k, v in task_type_map.items()}.get(task_type, task_type) + records.append( + { + "text": row.get("instruction", ""), + "category": category_name, + } + ) ds = Dataset.from_list(records) ds = ds.shuffle(seed=seed) @@ -230,6 +247,8 @@ def setup_gedit_dataset( if num_samples is not None: ds = ds.select(range(min(num_samples, len(ds)))) + if len(ds) == 0: + raise ValueError(f"No samples found for category '{category}'.") + pruna_logger.info("GEditBench is a test-only dataset. 
Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds - diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 8b9052fb..15f1fce7 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -4,7 +4,6 @@ import torch from transformers import AutoTokenizer -from pruna.data import BenchmarkInfo, benchmark_info from pruna.data.datasets.image import setup_imagenet_dataset from pruna.data.pruna_datamodule import PrunaDataModule @@ -52,13 +51,12 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: """Test the datamodule from a string.""" # get tokenizer if available - tokenizer = collate_fn_args.get("tokenizer", None) + tokenizer = collate_fn_args.get("tokenizer") # get the datamodule from the string datamodule = PrunaDataModule.from_string(dataset_name, collate_fn_args=collate_fn_args, tokenizer=tokenizer) datamodule.limit_datasets(10) - # iterate through the dataloaders iterate_dataloaders(datamodule) @@ -84,13 +82,10 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar iterate_dataloaders(datamodule) - @pytest.mark.slow def test_parti_prompts_with_category_filter(): """Test PartiPrompts loading with category filter.""" - dm = PrunaDataModule.from_string( - "PartiPrompts", category="Animals", dataloader_args={"batch_size": 4} - ) + dm = PrunaDataModule.from_string("PartiPrompts", category="Animals", dataloader_args={"batch_size": 4}) dm.limit_datasets(10) batch = next(iter(dm.test_dataloader())) prompts, auxiliaries = batch @@ -101,31 +96,27 @@ def test_parti_prompts_with_category_filter(): @pytest.mark.slow -def test_imgedit_with_subset_filter(): - """Test ImgEdit loading with subset filter.""" - dm = PrunaDataModule.from_string( - "ImgEdit", subset="replace", dataloader_args={"batch_size": 4} - ) +def test_imgedit_with_category_filter(): + """Test ImgEdit loading with category filter.""" + dm = PrunaDataModule.from_string("ImgEdit", category="replace", dataloader_args={"batch_size": 4}) dm.limit_datasets(10) batch = next(iter(dm.test_dataloader())) prompts, auxiliaries = batch assert len(prompts) == 4 assert all(isinstance(p, str) for p in prompts) - assert all(aux["subset"] == "replace" for aux in auxiliaries) + assert all(aux["category"] == "replace" for aux in auxiliaries) assert all("judge_prompt" in aux for aux in auxiliaries) @pytest.mark.slow -def test_geditbench_with_subset_filter(): - """Test GEditBench loading with subset filter.""" - dm = PrunaDataModule.from_string( - "GEditBench", subset="background_change", dataloader_args={"batch_size": 4} - ) +def test_geditbench_with_category_filter(): + """Test GEditBench loading with category filter.""" + dm = PrunaDataModule.from_string("GEditBench", category="background_change", dataloader_args={"batch_size": 4}) dm.limit_datasets(10) batch = next(iter(dm.test_dataloader())) prompts, auxiliaries = batch assert len(prompts) == 4 assert all(isinstance(p, str) for p in prompts) - assert all(aux["subset"] == "background_change" for aux in auxiliaries) + assert all(aux["category"] == "background_change" for aux in auxiliaries) From df02c6bc1dfef0b513b03999b1b69d6c54f01e83 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Mon, 2 Feb 2026 05:26:42 +0100 Subject: [PATCH 09/10] =?UTF-8?q?fix:=20align=20imgedit=20subset=E2=86=92c?= =?UTF-8?q?ategory=20and=20add=20empty=20dataset=20guards?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename subset to category in setup_imgedit_dataset for API consistency - Add empty dataset guard to setup_imgedit_dataset - Add empty dataset guard to setup_parti_prompts_dataset Co-authored-by: Cursor --- src/pruna/data/datasets/prompt.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index ee3a382f..000012fb 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -80,6 +80,10 @@ def setup_parti_prompts_dataset( ds = ds.select(range(min(num_samples, len(ds)))) ds = ds.rename_column("Prompt", "text") + + if len(ds) == 0: + raise ValueError(f"No samples found for category '{category}'.") + pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds @@ -106,12 +110,12 @@ def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: return ds.select([0]), ds.select([0]), ds -IMGEDIT_SUBSETS = ["replace", "add", "remove", "adjust", "extract", "style", "background", "compose"] +IMGEDIT_CATEGORIES = ["replace", "add", "remove", "adjust", "extract", "style", "background", "compose"] def setup_imgedit_dataset( seed: int, - subset: str | None = None, + category: str | None = None, num_samples: int | None = None, ) -> Tuple[Dataset, Dataset, Dataset]: """ @@ -123,9 +127,9 @@ def setup_imgedit_dataset( ---------- seed : int The seed to use. - subset : str | None + category : str | None Filter by edit type. Available: replace, add, remove, adjust, extract, style, - background, compose. If None, returns all subsets. + background, compose. If None, returns all categories. num_samples : int | None Maximum number of samples to return. If None, returns all samples. @@ -138,8 +142,8 @@ def setup_imgedit_dataset( import requests - if subset is not None and subset not in IMGEDIT_SUBSETS: - raise ValueError(f"Invalid subset: {subset}. Must be one of {IMGEDIT_SUBSETS}") + if category is not None and category not in IMGEDIT_CATEGORIES: + raise ValueError(f"Invalid category: {category}. Must be one of {IMGEDIT_CATEGORIES}") instructions_url = "https://raw.githubusercontent.com/PKU-YuanGroup/ImgEdit/b3eb8e74d7cd1fd0ce5341eaf9254744a8ab4c0b/Benchmark/Basic/basic_edit.json" judge_prompts_url = "https://raw.githubusercontent.com/PKU-YuanGroup/ImgEdit/c14480ac5e7b622e08cd8c46f96624a48eb9ab46/Benchmark/Basic/prompts.json" @@ -151,13 +155,13 @@ def setup_imgedit_dataset( for _, instruction in instructions.items(): edit_type = instruction.get("edit_type", "") - if subset is not None and edit_type != subset: + if category is not None and edit_type != category: continue records.append( { "text": instruction.get("prompt", ""), - "subset": edit_type, + "category": edit_type, "image_id": instruction.get("id", ""), "judge_prompt": judge_prompts.get(edit_type, ""), } @@ -169,6 +173,9 @@ def setup_imgedit_dataset( if num_samples is not None: ds = ds.select(range(min(num_samples, len(ds)))) + if len(ds) == 0: + raise ValueError(f"No samples found for category '{category}'.") + pruna_logger.info("ImgEdit is a test-only dataset. 
Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds From bb10cdf15120285726790906b778a340ab78099d Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Mon, 2 Feb 2026 06:05:44 +0100 Subject: [PATCH 10/10] fix: shorten ImgEdit description to fix line length linting Co-authored-by: Cursor --- src/pruna/data/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 52f6e27c..5dce04cd 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -220,7 +220,7 @@ class BenchmarkInfo: "ImgEdit": BenchmarkInfo( name="imgedit", display_name="ImgEdit", - description="Comprehensive image editing benchmark with 8 edit types: replace, add, remove, adjust, extract, style, background, compose.", + description="Image editing benchmark with 8 edit types for evaluating editing capabilities.", metrics=["accuracy"], task_type="image_edit", subsets=["replace", "add", "remove", "adjust", "extract", "style", "background", "compose"],
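
A minimal usage sketch, assuming the series above (PATCH 01-10) applies cleanly; it is not itself part of any patch. The calls mirror the tests added in tests/data/test_datamodule.py (test_parti_prompts_with_category_filter, test_imgedit_with_category_filter, test_geditbench_with_category_filter), which pass a category keyword through PrunaDataModule.from_string to the dataset setup function; that forwarding, and the exact batch layout, are assumptions taken from those tests rather than from any public API documentation.

    # Hypothetical usage of the benchmark loaders added in this series.
    # Assumes the patches are applied; names and keywords are copied from the
    # tests shown above, not from released Pruna documentation.
    from pruna.data import get_benchmark_info, list_benchmarks
    from pruna.data.pruna_datamodule import PrunaDataModule

    # Load the ImgEdit benchmark, keeping only the "replace" edit type,
    # the same call the PATCH 06/09 test uses.
    dm = PrunaDataModule.from_string(
        "ImgEdit", category="replace", dataloader_args={"batch_size": 4}
    )
    dm.limit_datasets(10)  # cap the dummy train/val splits and the test split for a quick run

    # prompt_with_auxiliaries_collate yields (prompts, auxiliaries) batches,
    # where each auxiliary dict carries the category and judge_prompt fields.
    prompts, auxiliaries = next(iter(dm.test_dataloader()))
    print(prompts[0])
    print(auxiliaries[0]["category"], sorted(auxiliaries[0].keys()))

    # Benchmark metadata helpers introduced in PATCH 04.
    print(list_benchmarks(task_type="image_edit"))
    print(get_benchmark_info("ImgEdit").metrics)  # expected: ["accuracy"]

The same pattern applies to "PartiPrompts" (category="Animals") and "GEditBench" (category="background_change"), since all three register a prompt_with_auxiliaries_collate entry in base_datasets.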