From 6db8f0b51522f78a78feb3e4c407939b1e491d84 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 22 Jan 2026 10:58:01 +0100 Subject: [PATCH 1/9] feat: add benchmark support to PrunaDataModule and implement PartiPrompts benchmark - Introduced `from_benchmark` method in `PrunaDataModule` to create instances from benchmark classes. - Added `Benchmark`, `BenchmarkEntry`, and `BenchmarkRegistry` classes for managing benchmarks. - Implemented `PartiPrompts` benchmark for text-to-image generation with various categories and challenges. - Created utility function `benchmark_to_datasets` to convert benchmarks into datasets compatible with `PrunaDataModule`. - Added integration tests for benchmark functionality and data module interactions. --- src/pruna/data/pruna_datamodule.py | 54 +++ src/pruna/evaluation/benchmarks/__init__.py | 24 ++ src/pruna/evaluation/benchmarks/adapter.py | 70 ++++ src/pruna/evaluation/benchmarks/base.py | 86 +++++ src/pruna/evaluation/benchmarks/registry.py | 66 ++++ .../benchmarks/text_to_image/__init__.py | 67 ++++ .../benchmarks/text_to_image/parti.py | 316 ++++++++++++++++++ tests/evaluation/test_benchmarks.py | 158 +++++++++ .../evaluation/test_benchmarks_integration.py | 123 +++++++ 9 files changed, 964 insertions(+) create mode 100644 src/pruna/evaluation/benchmarks/__init__.py create mode 100644 src/pruna/evaluation/benchmarks/adapter.py create mode 100644 src/pruna/evaluation/benchmarks/base.py create mode 100644 src/pruna/evaluation/benchmarks/registry.py create mode 100644 src/pruna/evaluation/benchmarks/text_to_image/__init__.py create mode 100644 src/pruna/evaluation/benchmarks/text_to_image/parti.py create mode 100644 tests/evaluation/test_benchmarks.py create mode 100644 tests/evaluation/test_benchmarks_integration.py diff --git a/src/pruna/data/pruna_datamodule.py b/src/pruna/data/pruna_datamodule.py index 435d7eec..30b47b29 100644 --- a/src/pruna/data/pruna_datamodule.py +++ b/src/pruna/data/pruna_datamodule.py @@ -25,6 +25,9 @@ from transformers.tokenization_utils import PreTrainedTokenizer as AutoTokenizer from pruna.data import base_datasets +from pruna.evaluation.benchmarks.adapter import benchmark_to_datasets +from pruna.evaluation.benchmarks.base import Benchmark +from pruna.evaluation.benchmarks.registry import BenchmarkRegistry from pruna.data.collate import pruna_collate_fns from pruna.data.utils import TokenizerMissingError from pruna.logging.logger import pruna_logger @@ -161,6 +164,13 @@ def from_string( PrunaDataModule The PrunaDataModule. """ + # Check if it's a benchmark first + benchmark_class = BenchmarkRegistry.get(dataset_name) + if benchmark_class is not None: + return cls.from_benchmark( + benchmark_class(seed=seed), tokenizer, collate_fn_args, dataloader_args + ) + setup_fn, collate_fn_name, default_collate_fn_args = base_datasets[dataset_name] # use default collate_fn_args and override with user-provided ones @@ -179,6 +189,50 @@ def from_string( (train_ds, val_ds, test_ds), collate_fn_name, tokenizer, collate_fn_args, dataloader_args ) + @classmethod + def from_benchmark( + cls, + benchmark: Benchmark, + tokenizer: AutoTokenizer | None = None, + collate_fn_args: dict = dict(), + dataloader_args: dict = dict(), + ) -> "PrunaDataModule": + """ + Create a PrunaDataModule from a Benchmark instance. + + Parameters + ---------- + benchmark : Benchmark + The benchmark instance. + tokenizer : AutoTokenizer | None + The tokenizer to use (if needed for the task type). 
+ collate_fn_args : dict + Any additional arguments for the collate function. + dataloader_args : dict + Any additional arguments for the dataloader. + + Returns + ------- + PrunaDataModule + The PrunaDataModule. + """ + train_ds, val_ds, test_ds = benchmark_to_datasets(benchmark) + + # Determine collate function based on task type + task_to_collate = { + "text_to_image": "prompt_collate", + "text_generation": "text_generation_collate", + "audio": "audio_collate", + "image_classification": "image_classification_collate", + "question_answering": "question_answering_collate", + } + + collate_fn_name = task_to_collate.get(benchmark.task_type, "prompt_collate") + + return cls.from_datasets( + (train_ds, val_ds, test_ds), collate_fn_name, tokenizer, collate_fn_args, dataloader_args + ) + def limit_datasets(self, limit: int | list[int] | tuple[int, int, int]) -> None: """ Limit the dataset to the given number of samples. diff --git a/src/pruna/evaluation/benchmarks/__init__.py b/src/pruna/evaluation/benchmarks/__init__.py new file mode 100644 index 00000000..8a5c9df8 --- /dev/null +++ b/src/pruna/evaluation/benchmarks/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pruna.evaluation.benchmarks.base import Benchmark, BenchmarkEntry, TASK +from pruna.evaluation.benchmarks.registry import BenchmarkRegistry + +# Auto-register all benchmarks +from pruna.evaluation.benchmarks import text_to_image # noqa: F401 + +# Auto-register all benchmark subclasses +BenchmarkRegistry.auto_register_subclasses(text_to_image) + +__all__ = ["Benchmark", "BenchmarkEntry", "BenchmarkRegistry", "TASK"] diff --git a/src/pruna/evaluation/benchmarks/adapter.py b/src/pruna/evaluation/benchmarks/adapter.py new file mode 100644 index 00000000..b82bb273 --- /dev/null +++ b/src/pruna/evaluation/benchmarks/adapter.py @@ -0,0 +1,70 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Tuple + +from datasets import Dataset + +from pruna.evaluation.benchmarks.base import Benchmark +from pruna.logging.logger import pruna_logger + + +def benchmark_to_datasets(benchmark: Benchmark) -> Tuple[Dataset, Dataset, Dataset]: + """ + Convert a Benchmark instance to train/val/test datasets compatible with PrunaDataModule. + + Parameters + ---------- + benchmark : Benchmark + The benchmark instance to convert. 
+ + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Train, validation, and test datasets. For test-only benchmarks, + train and val are dummy datasets with a single item. + """ + entries = list(benchmark) + + # Convert BenchmarkEntries to dict format expected by datasets + # For prompt-based benchmarks, we need "text" field for prompt_collate + data = [] + for entry in entries: + row = entry.model_inputs.copy() + row.update(entry.additional_info) + + # Ensure "text" field exists for prompt collate functions + if "text" not in row and "prompt" in row: + row["text"] = row["prompt"] + elif "text" not in row: + # If neither exists, use the first string value + for key, value in row.items(): + if isinstance(value, str): + row["text"] = value + break + + # Add path if needed for some collate functions + if "path" not in row: + row["path"] = entry.path + data.append(row) + + dataset = Dataset.from_list(data) + + # For test-only benchmarks (like PartiPrompts), create dummy train/val + pruna_logger.info(f"{benchmark.display_name} is a test-only dataset. Do not use it for training or validation.") + dummy = dataset.select([0]) if len(dataset) > 0 else dataset + + return dummy, dummy, dataset diff --git a/src/pruna/evaluation/benchmarks/base.py b/src/pruna/evaluation/benchmarks/base.py new file mode 100644 index 00000000..62837b03 --- /dev/null +++ b/src/pruna/evaluation/benchmarks/base.py @@ -0,0 +1,86 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Iterator, List, Literal + +TASK = Literal[ + "text_to_image", + "text_generation", + "audio", + "image_classification", + "question_answering", +] + + +@dataclass +class BenchmarkEntry: + """A single entry in a benchmark dataset.""" + + model_inputs: dict[str, Any] + model_outputs: dict[str, Any] = field(default_factory=dict) + path: str = "" + additional_info: dict[str, Any] = field(default_factory=dict) + task_type: TASK = "text_to_image" + + +class Benchmark(ABC): + """Base class for all benchmark datasets.""" + + def __init__(self): + """Initialize the benchmark. 
Override to load data lazily or eagerly.""" + pass + + @abstractmethod + def __iter__(self) -> Iterator[BenchmarkEntry]: + """Iterate over benchmark entries.""" + pass + + @property + @abstractmethod + def name(self) -> str: + """Return the unique name identifier for this benchmark.""" + pass + + @property + @abstractmethod + def display_name(self) -> str: + """Return the human-readable display name for this benchmark.""" + pass + + @abstractmethod + def __len__(self) -> int: + """Return the number of items in the benchmark.""" + pass + + @property + @abstractmethod + def metrics(self) -> List[str]: + """Return the list of metric names recommended for this benchmark.""" + pass + + @property + @abstractmethod + def task_type(self) -> TASK: + """Return the task type for this benchmark.""" + pass + + @property + @abstractmethod + def description(self) -> str: + """Return a description of this benchmark.""" + pass diff --git a/src/pruna/evaluation/benchmarks/registry.py b/src/pruna/evaluation/benchmarks/registry.py new file mode 100644 index 00000000..2e6399ab --- /dev/null +++ b/src/pruna/evaluation/benchmarks/registry.py @@ -0,0 +1,66 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from typing import Type + +from pruna.evaluation.benchmarks.base import Benchmark + + +class BenchmarkRegistry: + """Registry for automatically discovering and registering benchmark classes.""" + + _registry: dict[str, Type[Benchmark]] = {} + + @classmethod + def register(cls, benchmark_class: Type[Benchmark]) -> Type[Benchmark]: + """Register a benchmark class by its name property.""" + # Create instance with default args to get the name + # This assumes benchmarks have default or no required arguments + try: + instance = benchmark_class() + name = instance.name + except Exception as e: + raise ValueError( + f"Failed to create instance of {benchmark_class.__name__} for registration: {e}. " + "Ensure the benchmark class can be instantiated with default arguments." 
+ ) from e + + if name in cls._registry: + raise ValueError(f"Benchmark with name '{name}' is already registered.") + cls._registry[name] = benchmark_class + return benchmark_class + + @classmethod + def get(cls, name: str) -> Type[Benchmark] | None: + """Get a benchmark class by name.""" + return cls._registry.get(name) + + @classmethod + def list_all(cls) -> dict[str, Type[Benchmark]]: + """List all registered benchmarks.""" + return cls._registry.copy() + + @classmethod + def auto_register_subclasses(cls, module) -> None: + """Automatically register all Benchmark subclasses in a module.""" + for name, obj in inspect.getmembers(module, inspect.isclass): + if ( + issubclass(obj, Benchmark) + and obj is not Benchmark + and (obj.__module__ == module.__name__ or obj.__module__.startswith(module.__name__ + ".")) + ): + cls.register(obj) diff --git a/src/pruna/evaluation/benchmarks/text_to_image/__init__.py b/src/pruna/evaluation/benchmarks/text_to_image/__init__.py new file mode 100644 index 00000000..d5f38c69 --- /dev/null +++ b/src/pruna/evaluation/benchmarks/text_to_image/__init__.py @@ -0,0 +1,67 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pruna.evaluation.benchmarks.text_to_image.parti import ( + PartiPrompts, + PartiPromptsAbstract, + PartiPromptsAnimals, + PartiPromptsArtifacts, + PartiPromptsArts, + PartiPromptsBasic, + PartiPromptsComplex, + PartiPromptsFineGrainedDetail, + PartiPromptsFoodBeverage, + PartiPromptsImagination, + PartiPromptsIllustrations, + PartiPromptsIndoorScenes, + PartiPromptsLinguisticStructures, + PartiPromptsOutdoorScenes, + PartiPromptsPeople, + PartiPromptsPerspective, + PartiPromptsProducePlants, + PartiPromptsPropertiesPositioning, + PartiPromptsQuantity, + PartiPromptsSimpleDetail, + PartiPromptsStyleFormat, + PartiPromptsVehicles, + PartiPromptsWorldKnowledge, + PartiPromptsWritingSymbols, +) + +__all__ = [ + "PartiPrompts", + "PartiPromptsAbstract", + "PartiPromptsAnimals", + "PartiPromptsArtifacts", + "PartiPromptsArts", + "PartiPromptsBasic", + "PartiPromptsComplex", + "PartiPromptsFineGrainedDetail", + "PartiPromptsFoodBeverage", + "PartiPromptsImagination", + "PartiPromptsIllustrations", + "PartiPromptsIndoorScenes", + "PartiPromptsLinguisticStructures", + "PartiPromptsOutdoorScenes", + "PartiPromptsPeople", + "PartiPromptsPerspective", + "PartiPromptsProducePlants", + "PartiPromptsPropertiesPositioning", + "PartiPromptsQuantity", + "PartiPromptsSimpleDetail", + "PartiPromptsStyleFormat", + "PartiPromptsVehicles", + "PartiPromptsWorldKnowledge", + "PartiPromptsWritingSymbols", +] diff --git a/src/pruna/evaluation/benchmarks/text_to_image/parti.py b/src/pruna/evaluation/benchmarks/text_to_image/parti.py new file mode 100644 index 00000000..089b64df --- /dev/null +++ b/src/pruna/evaluation/benchmarks/text_to_image/parti.py @@ -0,0 +1,316 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Iterator, List, cast + +from datasets import Dataset, load_dataset + +from pruna.evaluation.benchmarks.base import TASK, Benchmark, BenchmarkEntry + + +class PartiPrompts(Benchmark): + """Parti Prompts benchmark for text-to-image generation.""" + + def __init__( + self, + seed: int = 42, + num_samples: int | None = None, + subset: str | None = None, + ): + """ + Initialize the Parti Prompts benchmark. + + Parameters + ---------- + seed : int + Random seed for shuffling. Default is 42. + num_samples : int | None + Number of samples to select. If None, uses all samples. Default is None. + subset : str | None + Filter by a subset of the dataset. For PartiPrompts, this can be either: + + **Categories:** + - "Abstract" + - "Animals" + - "Artifacts" + - "Arts" + - "Food & Beverage" + - "Illustrations" + - "Indoor Scenes" + - "Outdoor Scenes" + - "People" + - "Produce & Plants" + - "Vehicles" + - "World Knowledge" + + **Challenges:** + - "Basic" + - "Complex" + - "Fine-grained Detail" + - "Imagination" + - "Linguistic Structures" + - "Perspective" + - "Properties & Positioning" + - "Quantity" + - "Simple Detail" + - "Style & Format" + - "Writing & Symbols" + + If None, includes all samples. Default is None. 
+ """ + super().__init__() + self._seed = seed + self._num_samples = num_samples + + # Determine if subset refers to a dataset category or challenge + # Check against known challenges + self.subset = subset + + def _load_prompts(self) -> List[dict]: + """Load prompts from the dataset.""" + dataset_dict = load_dataset("nateraw/parti-prompts") # type: ignore + dataset = cast(Dataset, dataset_dict["train"]) # type: ignore + if self.subset is not None: + dataset = dataset.filter(lambda x: x["Category"] == self.subset or x["Challenge"] == self.subset) + shuffled_dataset = dataset.shuffle(seed=self._seed) + if self._num_samples is not None: + selected_dataset = shuffled_dataset.select(range(min(self._num_samples, len(shuffled_dataset)))) + else: + selected_dataset = shuffled_dataset + return list(selected_dataset) + + def __iter__(self) -> Iterator[BenchmarkEntry]: + """Iterate over benchmark entries.""" + for i, row in enumerate(self._load_prompts()): + yield BenchmarkEntry( + model_inputs={"prompt": row["Prompt"]}, + model_outputs={}, + path=f"{i}.png", + additional_info={ + "category": row["Category"], + "challenge": row["Challenge"], + "note": row.get("Note", ""), + }, + task_type=self.task_type, + ) + + @property + def name(self) -> str: + """Return the unique name identifier.""" + if self.subset is None: + return "parti_prompts" + normalized = ( + self.subset.lower().replace(" & ", "_").replace(" ", "_").replace("&", "_").replace("__", "_").rstrip("_") + ) + return f"parti_prompts_{normalized}" + + @property + def display_name(self) -> str: + """Return the human-readable display name.""" + if self.subset is None: + return "Parti Prompts" + return f"Parti Prompts ({self.subset})" + + def __len__(self) -> int: + """Return the number of entries in the benchmark.""" + return len(self._load_prompts()) + + @property + def metrics(self) -> List[str]: + """Return the list of recommended metrics.""" + return ["arniqa", "clip", "clip_iqa", "sharpness"] + + @property + def task_type(self) -> TASK: + """Return the task type.""" + return "text_to_image" + + @property + def description(self) -> str: + """Return a description of the benchmark.""" + return ( + "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects " + "ranging from basic to complex, enabling comprehensive assessment of model capabilities " + "across different domains and difficulty levels." 
+ ) + + +# Category-based subclasses +class PartiPromptsAbstract(PartiPrompts): + """Parti Prompts filtered by Abstract category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Abstract" if subset is None else subset) + + +class PartiPromptsAnimals(PartiPrompts): + """Parti Prompts filtered by Animals category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Animals" if subset is None else subset) + + +class PartiPromptsArtifacts(PartiPrompts): + """Parti Prompts filtered by Artifacts category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Artifacts" if subset is None else subset) + + +class PartiPromptsArts(PartiPrompts): + """Parti Prompts filtered by Arts category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Arts" if subset is None else subset) + + +class PartiPromptsFoodBeverage(PartiPrompts): + """Parti Prompts filtered by Food & Beverage category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Food & Beverage" if subset is None else subset) + + +class PartiPromptsIllustrations(PartiPrompts): + """Parti Prompts filtered by Illustrations category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Illustrations" if subset is None else subset) + + +class PartiPromptsIndoorScenes(PartiPrompts): + """Parti Prompts filtered by Indoor Scenes category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Indoor Scenes" if subset is None else subset) + + +class PartiPromptsOutdoorScenes(PartiPrompts): + """Parti Prompts filtered by Outdoor Scenes category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Outdoor Scenes" if subset is None else subset) + + +class PartiPromptsPeople(PartiPrompts): + """Parti Prompts filtered by People category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="People" if subset is None else subset) + + +class PartiPromptsProducePlants(PartiPrompts): + """Parti Prompts filtered by Produce & Plants category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Produce & Plants" if subset is None else subset) + + +class PartiPromptsVehicles(PartiPrompts): + """Parti Prompts filtered by Vehicles category.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Vehicles" if subset is None else subset) + + +class PartiPromptsWorldKnowledge(PartiPrompts): + """Parti Prompts filtered by World Knowledge category.""" + + def 
__init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="World Knowledge" if subset is None else subset) + + +# Challenge-based subclasses +class PartiPromptsBasic(PartiPrompts): + """Parti Prompts filtered by Basic challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + # subset can be a category to further filter when challenge is already set + super().__init__(seed=seed, num_samples=num_samples, subset="Basic" if subset is None else subset) + + +class PartiPromptsComplex(PartiPrompts): + """Parti Prompts filtered by Complex challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Complex" if subset is None else subset) + + +class PartiPromptsFineGrainedDetail(PartiPrompts): + """Parti Prompts filtered by Fine-grained Detail challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Fine-grained Detail" if subset is None else subset) + + +class PartiPromptsImagination(PartiPrompts): + """Parti Prompts filtered by Imagination challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Imagination" if subset is None else subset) + + +class PartiPromptsLinguisticStructures(PartiPrompts): + """Parti Prompts filtered by Linguistic Structures challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__( + seed=seed, num_samples=num_samples, subset="Linguistic Structures" if subset is None else subset + ) + + +class PartiPromptsPerspective(PartiPrompts): + """Parti Prompts filtered by Perspective challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Perspective" if subset is None else subset) + + +class PartiPromptsPropertiesPositioning(PartiPrompts): + """Parti Prompts filtered by Properties & Positioning challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__( + seed=seed, num_samples=num_samples, subset="Properties & Positioning" if subset is None else subset + ) + + +class PartiPromptsQuantity(PartiPrompts): + """Parti Prompts filtered by Quantity challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Quantity" if subset is None else subset) + + +class PartiPromptsSimpleDetail(PartiPrompts): + """Parti Prompts filtered by Simple Detail challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Simple Detail" if subset is None else subset) + + +class PartiPromptsStyleFormat(PartiPrompts): + """Parti Prompts filtered by Style & Format challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Style & Format" if subset is None else subset) + + +class 
PartiPromptsWritingSymbols(PartiPrompts): + """Parti Prompts filtered by Writing & Symbols challenge.""" + + def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): + super().__init__(seed=seed, num_samples=num_samples, subset="Writing & Symbols" if subset is None else subset) diff --git a/tests/evaluation/test_benchmarks.py b/tests/evaluation/test_benchmarks.py new file mode 100644 index 00000000..d9345874 --- /dev/null +++ b/tests/evaluation/test_benchmarks.py @@ -0,0 +1,158 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the benchmarks module.""" + +import pytest + +from pruna.evaluation.benchmarks.base import Benchmark, BenchmarkEntry, TASK +from pruna.evaluation.benchmarks.registry import BenchmarkRegistry +from pruna.evaluation.benchmarks.adapter import benchmark_to_datasets +from pruna.evaluation.benchmarks.text_to_image.parti import PartiPrompts + + +def test_benchmark_entry_creation(): + """Test creating a BenchmarkEntry with all fields.""" + entry = BenchmarkEntry( + model_inputs={"prompt": "test prompt"}, + model_outputs={"image": "test_image.png"}, + path="test.png", + additional_info={"category": "test"}, + task_type="text_to_image", + ) + + assert entry.model_inputs == {"prompt": "test prompt"} + assert entry.model_outputs == {"image": "test_image.png"} + assert entry.path == "test.png" + assert entry.additional_info == {"category": "test"} + assert entry.task_type == "text_to_image" + + +def test_benchmark_entry_defaults(): + """Test BenchmarkEntry with default values.""" + entry = BenchmarkEntry(model_inputs={"prompt": "test"}) + + assert entry.model_inputs == {"prompt": "test"} + assert entry.model_outputs == {} + assert entry.path == "" + assert entry.additional_info == {} + assert entry.task_type == "text_to_image" + + +def test_task_type_literal(): + """Test that TASK type only accepts valid task types.""" + # Valid task types + valid_tasks: list[TASK] = [ + "text_to_image", + "text_generation", + "audio", + "image_classification", + "question_answering", + ] + + for task in valid_tasks: + entry = BenchmarkEntry(model_inputs={}, task_type=task) + assert entry.task_type == task + + +def test_benchmark_registry_get(): + """Test getting a benchmark from the registry.""" + benchmark_class = BenchmarkRegistry.get("parti_prompts") + assert benchmark_class is not None + assert issubclass(benchmark_class, Benchmark) + + +def test_benchmark_registry_list_all(): + """Test listing all registered benchmarks.""" + all_benchmarks = BenchmarkRegistry.list_all() + assert isinstance(all_benchmarks, dict) + assert len(all_benchmarks) > 0 + assert "parti_prompts" in all_benchmarks + + +def test_benchmark_registry_get_nonexistent(): + """Test getting a non-existent benchmark returns None.""" + benchmark_class = BenchmarkRegistry.get("nonexistent_benchmark") + assert benchmark_class is None + + +def test_parti_prompts_creation(): + """Test creating a PartiPrompts benchmark 
instance.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + + assert benchmark.name == "parti_prompts" + assert benchmark.display_name == "Parti Prompts" + assert benchmark.task_type == "text_to_image" + assert len(benchmark.metrics) > 0 + assert isinstance(benchmark.description, str) + + +def test_parti_prompts_iteration(): + """Test iterating over PartiPrompts entries.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + entries = list(benchmark) + + assert len(entries) == 5 + for entry in entries: + assert isinstance(entry, BenchmarkEntry) + assert "prompt" in entry.model_inputs + assert entry.task_type == "text_to_image" + assert entry.model_outputs == {} + + +def test_parti_prompts_length(): + """Test PartiPrompts __len__ method.""" + benchmark = PartiPrompts(seed=42, num_samples=10) + assert len(benchmark) == 10 + + +def test_parti_prompts_subset(): + """Test PartiPrompts with a subset filter.""" + benchmark = PartiPrompts(seed=42, num_samples=5, subset="Animals") + + assert "animals" in benchmark.name.lower() + assert "Animals" in benchmark.display_name + + entries = list(benchmark) + for entry in entries: + assert entry.additional_info.get("category") == "Animals" + + +def test_benchmark_to_datasets(): + """Test converting a benchmark to datasets.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + train_ds, val_ds, test_ds = benchmark_to_datasets(benchmark) + + assert len(test_ds) == 5 + assert len(train_ds) == 1 # Dummy dataset + assert len(val_ds) == 1 # Dummy dataset + + # Check that test dataset has the expected fields + sample = test_ds[0] + assert "prompt" in sample or "text" in sample + + +def test_benchmark_entry_task_type_validation(): + """Test that BenchmarkEntry validates task_type.""" + # This should work + entry = BenchmarkEntry( + model_inputs={}, + task_type="text_to_image", + ) + assert entry.task_type == "text_to_image" + + # Test other valid task types + for task in ["text_generation", "audio", "image_classification", "question_answering"]: + entry = BenchmarkEntry(model_inputs={}, task_type=task) + assert entry.task_type == task diff --git a/tests/evaluation/test_benchmarks_integration.py b/tests/evaluation/test_benchmarks_integration.py new file mode 100644 index 00000000..e7198f07 --- /dev/null +++ b/tests/evaluation/test_benchmarks_integration.py @@ -0,0 +1,123 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration tests for benchmarks with datamodule and metrics.""" + +import pytest + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.evaluation.benchmarks.registry import BenchmarkRegistry +from pruna.evaluation.benchmarks.text_to_image.parti import PartiPrompts +from pruna.evaluation.metrics.registry import MetricRegistry + + +@pytest.mark.cpu +def test_datamodule_from_benchmark(): + """Test creating a PrunaDataModule from a benchmark.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + datamodule = PrunaDataModule.from_benchmark(benchmark) + + assert datamodule is not None + assert datamodule.test_dataset is not None + assert len(datamodule.test_dataset) == 5 + + +@pytest.mark.cpu +def test_datamodule_from_benchmark_string(): + """Test creating a PrunaDataModule from a benchmark name string.""" + datamodule = PrunaDataModule.from_string("parti_prompts", seed=42) + + assert datamodule is not None + # Limit to small number for testing + datamodule.limit_datasets(5) + + # Test that we can iterate through the dataloader + test_loader = datamodule.test_dataloader(batch_size=2) + batch = next(iter(test_loader)) + assert batch is not None + + +@pytest.mark.cpu +def test_benchmark_with_metrics(): + """Test that benchmarks provide recommended metrics.""" + benchmark = PartiPrompts(seed=42, num_samples=5) + recommended_metrics = benchmark.metrics + + assert isinstance(recommended_metrics, list) + assert len(recommended_metrics) > 0 + + # Check that metrics can be retrieved from registry + for metric_name in recommended_metrics: + # Some metrics might be registered, some might not + # Just verify the names are strings + assert isinstance(metric_name, str) + + +@pytest.mark.cpu +def test_benchmark_registry_integration(): + """Test that benchmarks are properly registered and can be used.""" + # Get benchmark from registry + benchmark_class = BenchmarkRegistry.get("parti_prompts") + assert benchmark_class is not None + + # Create instance + benchmark = benchmark_class(seed=42, num_samples=3) + + # Verify it works with datamodule + datamodule = PrunaDataModule.from_benchmark(benchmark) + assert datamodule is not None + + # Verify we can get entries + entries = list(benchmark) + assert len(entries) == 3 + + +@pytest.mark.cpu +def test_benchmark_task_type_mapping(): + """Test that benchmark task types map correctly to collate functions.""" + benchmark = PartiPrompts(seed=42, num_samples=3) + + # Create datamodule and verify it uses the correct collate function + datamodule = PrunaDataModule.from_benchmark(benchmark) + + # The collate function should be set based on task_type + assert datamodule.collate_fn is not None + + # Verify we can use the dataloader + test_loader = datamodule.test_dataloader(batch_size=1) + batch = next(iter(test_loader)) + assert batch is not None + + +@pytest.mark.cpu +def test_benchmark_entry_model_outputs(): + """Test that BenchmarkEntry can store model outputs.""" + from pruna.evaluation.benchmarks.base import BenchmarkEntry + + entry = BenchmarkEntry( + model_inputs={"prompt": "test"}, + model_outputs={"image": "generated_image.png", "score": 0.95}, + ) + + assert entry.model_outputs == {"image": "generated_image.png", "score": 0.95} + + # Verify entries from benchmark have empty model_outputs by default + benchmark = PartiPrompts(seed=42, num_samples=2) + entries = list(benchmark) + + for entry in entries: + assert entry.model_outputs == {} + # But model_outputs field exists and can be populated + entry.model_outputs["test"] = "value" + assert 
entry.model_outputs["test"] == "value" From 7c53c95a654528b4851ba8f93dbc7bb0aa7339d8 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 15:50:16 +0100 Subject: [PATCH 2/9] refactor: simplify benchmark system, extend PartiPrompts with subset filtering - Remove heavy benchmark abstraction (Benchmark class, registry, adapter, 24 subclasses) - Extend setup_parti_prompts_dataset with category and num_samples params - Add BenchmarkInfo dataclass for metadata (metrics, description, subsets) - Switch PartiPrompts to prompt_with_auxiliaries_collate to preserve Category/Challenge - Merge tests into test_datamodule.py Reduces 964 lines to 128 lines (87% reduction) Co-authored-by: Cursor --- src/pruna/data/__init__.py | 76 ++++- src/pruna/data/datasets/prompt.py | 25 +- src/pruna/data/pruna_datamodule.py | 54 --- src/pruna/evaluation/benchmarks/__init__.py | 24 -- src/pruna/evaluation/benchmarks/adapter.py | 70 ---- src/pruna/evaluation/benchmarks/base.py | 86 ----- src/pruna/evaluation/benchmarks/registry.py | 66 ---- .../benchmarks/text_to_image/__init__.py | 67 ---- .../benchmarks/text_to_image/parti.py | 316 ------------------ tests/data/test_datamodule.py | 22 +- tests/evaluation/test_benchmarks.py | 158 --------- .../evaluation/test_benchmarks_integration.py | 123 ------- 12 files changed, 117 insertions(+), 970 deletions(-) delete mode 100644 src/pruna/evaluation/benchmarks/__init__.py delete mode 100644 src/pruna/evaluation/benchmarks/adapter.py delete mode 100644 src/pruna/evaluation/benchmarks/base.py delete mode 100644 src/pruna/evaluation/benchmarks/registry.py delete mode 100644 src/pruna/evaluation/benchmarks/text_to_image/__init__.py delete mode 100644 src/pruna/evaluation/benchmarks/text_to_image/parti.py delete mode 100644 tests/evaluation/test_benchmarks.py delete mode 100644 tests/evaluation/test_benchmarks_integration.py diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 3a811868..820d1262 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass, field from functools import partial from typing import Any, Callable, Tuple @@ -97,8 +98,81 @@ {"img_size": 224}, ), "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), - "PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}), + "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}), "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), } + + +@dataclass +class BenchmarkInfo: + """Metadata for a benchmark dataset.""" + + name: str + display_name: str + description: str + metrics: list[str] + task_type: str + subsets: list[str] = field(default_factory=list) + + +benchmark_info: dict[str, BenchmarkInfo] = { + "PartiPrompts": BenchmarkInfo( + name="parti_prompts", + display_name="Parti Prompts", + description=( + "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects " + "ranging from basic to complex, enabling comprehensive assessment of model capabilities " + "across different domains and difficulty levels." 
+ ), + metrics=["arniqa", "clip", "clip_iqa", "sharpness"], + task_type="text_to_image", + subsets=[ + "Abstract", + "Animals", + "Artifacts", + "Arts", + "Food & Beverage", + "Illustrations", + "Indoor Scenes", + "Outdoor Scenes", + "People", + "Produce & Plants", + "Vehicles", + "World Knowledge", + "Basic", + "Complex", + "Fine-grained Detail", + "Imagination", + "Linguistic Structures", + "Perspective", + "Properties & Positioning", + "Quantity", + "Simple Detail", + "Style & Format", + "Writing & Symbols", + ], + ), + "DrawBench": BenchmarkInfo( + name="drawbench", + display_name="DrawBench", + description="A comprehensive benchmark for evaluating text-to-image generation models.", + metrics=["clip", "clip_iqa", "sharpness"], + task_type="text_to_image", + ), + "GenAIBench": BenchmarkInfo( + name="genai_bench", + display_name="GenAI Bench", + description="A benchmark for evaluating generative AI models.", + metrics=["clip", "clip_iqa", "sharpness"], + task_type="text_to_image", + ), + "VBench": BenchmarkInfo( + name="vbench", + display_name="VBench", + description="A benchmark for evaluating video generation models.", + metrics=["clip", "fvd"], + task_type="text_to_video", + ), +} diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 1f6fab71..4f275675 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: return ds.select([0]), ds.select([0]), ds -def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: +def setup_parti_prompts_dataset( + seed: int, + category: str | None = None, + num_samples: int | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: """ Setup the Parti Prompts dataset. @@ -51,13 +55,30 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: ---------- seed : int The seed to use. + category : str | None + Filter by Category or Challenge. Available categories: Abstract, Animals, Artifacts, + Arts, Food & Beverage, Illustrations, Indoor Scenes, Outdoor Scenes, People, + Produce & Plants, Vehicles, World Knowledge. Available challenges: Basic, Complex, + Fine-grained Detail, Imagination, Linguistic Structures, Perspective, + Properties & Positioning, Quantity, Simple Detail, Style & Format, Writing & Symbols. + num_samples : int | None + Maximum number of samples to return. If None, returns all samples. Returns ------- Tuple[Dataset, Dataset, Dataset] - The Parti Prompts dataset. + The Parti Prompts dataset (dummy train, dummy val, test). """ ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index] + + if category is not None: + ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category) + + ds = ds.shuffle(seed=seed) + + if num_samples is not None: + ds = ds.select(range(min(num_samples, len(ds)))) + ds = ds.rename_column("Prompt", "text") pruna_logger.info("PartiPrompts is a test-only dataset. 
Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds diff --git a/src/pruna/data/pruna_datamodule.py b/src/pruna/data/pruna_datamodule.py index 30b47b29..435d7eec 100644 --- a/src/pruna/data/pruna_datamodule.py +++ b/src/pruna/data/pruna_datamodule.py @@ -25,9 +25,6 @@ from transformers.tokenization_utils import PreTrainedTokenizer as AutoTokenizer from pruna.data import base_datasets -from pruna.evaluation.benchmarks.adapter import benchmark_to_datasets -from pruna.evaluation.benchmarks.base import Benchmark -from pruna.evaluation.benchmarks.registry import BenchmarkRegistry from pruna.data.collate import pruna_collate_fns from pruna.data.utils import TokenizerMissingError from pruna.logging.logger import pruna_logger @@ -164,13 +161,6 @@ def from_string( PrunaDataModule The PrunaDataModule. """ - # Check if it's a benchmark first - benchmark_class = BenchmarkRegistry.get(dataset_name) - if benchmark_class is not None: - return cls.from_benchmark( - benchmark_class(seed=seed), tokenizer, collate_fn_args, dataloader_args - ) - setup_fn, collate_fn_name, default_collate_fn_args = base_datasets[dataset_name] # use default collate_fn_args and override with user-provided ones @@ -189,50 +179,6 @@ def from_string( (train_ds, val_ds, test_ds), collate_fn_name, tokenizer, collate_fn_args, dataloader_args ) - @classmethod - def from_benchmark( - cls, - benchmark: Benchmark, - tokenizer: AutoTokenizer | None = None, - collate_fn_args: dict = dict(), - dataloader_args: dict = dict(), - ) -> "PrunaDataModule": - """ - Create a PrunaDataModule from a Benchmark instance. - - Parameters - ---------- - benchmark : Benchmark - The benchmark instance. - tokenizer : AutoTokenizer | None - The tokenizer to use (if needed for the task type). - collate_fn_args : dict - Any additional arguments for the collate function. - dataloader_args : dict - Any additional arguments for the dataloader. - - Returns - ------- - PrunaDataModule - The PrunaDataModule. - """ - train_ds, val_ds, test_ds = benchmark_to_datasets(benchmark) - - # Determine collate function based on task type - task_to_collate = { - "text_to_image": "prompt_collate", - "text_generation": "text_generation_collate", - "audio": "audio_collate", - "image_classification": "image_classification_collate", - "question_answering": "question_answering_collate", - } - - collate_fn_name = task_to_collate.get(benchmark.task_type, "prompt_collate") - - return cls.from_datasets( - (train_ds, val_ds, test_ds), collate_fn_name, tokenizer, collate_fn_args, dataloader_args - ) - def limit_datasets(self, limit: int | list[int] | tuple[int, int, int]) -> None: """ Limit the dataset to the given number of samples. diff --git a/src/pruna/evaluation/benchmarks/__init__.py b/src/pruna/evaluation/benchmarks/__init__.py deleted file mode 100644 index 8a5c9df8..00000000 --- a/src/pruna/evaluation/benchmarks/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from pruna.evaluation.benchmarks.base import Benchmark, BenchmarkEntry, TASK -from pruna.evaluation.benchmarks.registry import BenchmarkRegistry - -# Auto-register all benchmarks -from pruna.evaluation.benchmarks import text_to_image # noqa: F401 - -# Auto-register all benchmark subclasses -BenchmarkRegistry.auto_register_subclasses(text_to_image) - -__all__ = ["Benchmark", "BenchmarkEntry", "BenchmarkRegistry", "TASK"] diff --git a/src/pruna/evaluation/benchmarks/adapter.py b/src/pruna/evaluation/benchmarks/adapter.py deleted file mode 100644 index b82bb273..00000000 --- a/src/pruna/evaluation/benchmarks/adapter.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Tuple - -from datasets import Dataset - -from pruna.evaluation.benchmarks.base import Benchmark -from pruna.logging.logger import pruna_logger - - -def benchmark_to_datasets(benchmark: Benchmark) -> Tuple[Dataset, Dataset, Dataset]: - """ - Convert a Benchmark instance to train/val/test datasets compatible with PrunaDataModule. - - Parameters - ---------- - benchmark : Benchmark - The benchmark instance to convert. - - Returns - ------- - Tuple[Dataset, Dataset, Dataset] - Train, validation, and test datasets. For test-only benchmarks, - train and val are dummy datasets with a single item. - """ - entries = list(benchmark) - - # Convert BenchmarkEntries to dict format expected by datasets - # For prompt-based benchmarks, we need "text" field for prompt_collate - data = [] - for entry in entries: - row = entry.model_inputs.copy() - row.update(entry.additional_info) - - # Ensure "text" field exists for prompt collate functions - if "text" not in row and "prompt" in row: - row["text"] = row["prompt"] - elif "text" not in row: - # If neither exists, use the first string value - for key, value in row.items(): - if isinstance(value, str): - row["text"] = value - break - - # Add path if needed for some collate functions - if "path" not in row: - row["path"] = entry.path - data.append(row) - - dataset = Dataset.from_list(data) - - # For test-only benchmarks (like PartiPrompts), create dummy train/val - pruna_logger.info(f"{benchmark.display_name} is a test-only dataset. Do not use it for training or validation.") - dummy = dataset.select([0]) if len(dataset) > 0 else dataset - - return dummy, dummy, dataset diff --git a/src/pruna/evaluation/benchmarks/base.py b/src/pruna/evaluation/benchmarks/base.py deleted file mode 100644 index 62837b03..00000000 --- a/src/pruna/evaluation/benchmarks/base.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Iterator, List, Literal - -TASK = Literal[ - "text_to_image", - "text_generation", - "audio", - "image_classification", - "question_answering", -] - - -@dataclass -class BenchmarkEntry: - """A single entry in a benchmark dataset.""" - - model_inputs: dict[str, Any] - model_outputs: dict[str, Any] = field(default_factory=dict) - path: str = "" - additional_info: dict[str, Any] = field(default_factory=dict) - task_type: TASK = "text_to_image" - - -class Benchmark(ABC): - """Base class for all benchmark datasets.""" - - def __init__(self): - """Initialize the benchmark. Override to load data lazily or eagerly.""" - pass - - @abstractmethod - def __iter__(self) -> Iterator[BenchmarkEntry]: - """Iterate over benchmark entries.""" - pass - - @property - @abstractmethod - def name(self) -> str: - """Return the unique name identifier for this benchmark.""" - pass - - @property - @abstractmethod - def display_name(self) -> str: - """Return the human-readable display name for this benchmark.""" - pass - - @abstractmethod - def __len__(self) -> int: - """Return the number of items in the benchmark.""" - pass - - @property - @abstractmethod - def metrics(self) -> List[str]: - """Return the list of metric names recommended for this benchmark.""" - pass - - @property - @abstractmethod - def task_type(self) -> TASK: - """Return the task type for this benchmark.""" - pass - - @property - @abstractmethod - def description(self) -> str: - """Return a description of this benchmark.""" - pass diff --git a/src/pruna/evaluation/benchmarks/registry.py b/src/pruna/evaluation/benchmarks/registry.py deleted file mode 100644 index 2e6399ab..00000000 --- a/src/pruna/evaluation/benchmarks/registry.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import inspect -from typing import Type - -from pruna.evaluation.benchmarks.base import Benchmark - - -class BenchmarkRegistry: - """Registry for automatically discovering and registering benchmark classes.""" - - _registry: dict[str, Type[Benchmark]] = {} - - @classmethod - def register(cls, benchmark_class: Type[Benchmark]) -> Type[Benchmark]: - """Register a benchmark class by its name property.""" - # Create instance with default args to get the name - # This assumes benchmarks have default or no required arguments - try: - instance = benchmark_class() - name = instance.name - except Exception as e: - raise ValueError( - f"Failed to create instance of {benchmark_class.__name__} for registration: {e}. " - "Ensure the benchmark class can be instantiated with default arguments." - ) from e - - if name in cls._registry: - raise ValueError(f"Benchmark with name '{name}' is already registered.") - cls._registry[name] = benchmark_class - return benchmark_class - - @classmethod - def get(cls, name: str) -> Type[Benchmark] | None: - """Get a benchmark class by name.""" - return cls._registry.get(name) - - @classmethod - def list_all(cls) -> dict[str, Type[Benchmark]]: - """List all registered benchmarks.""" - return cls._registry.copy() - - @classmethod - def auto_register_subclasses(cls, module) -> None: - """Automatically register all Benchmark subclasses in a module.""" - for name, obj in inspect.getmembers(module, inspect.isclass): - if ( - issubclass(obj, Benchmark) - and obj is not Benchmark - and (obj.__module__ == module.__name__ or obj.__module__.startswith(module.__name__ + ".")) - ): - cls.register(obj) diff --git a/src/pruna/evaluation/benchmarks/text_to_image/__init__.py b/src/pruna/evaluation/benchmarks/text_to_image/__init__.py deleted file mode 100644 index d5f38c69..00000000 --- a/src/pruna/evaluation/benchmarks/text_to_image/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from pruna.evaluation.benchmarks.text_to_image.parti import ( - PartiPrompts, - PartiPromptsAbstract, - PartiPromptsAnimals, - PartiPromptsArtifacts, - PartiPromptsArts, - PartiPromptsBasic, - PartiPromptsComplex, - PartiPromptsFineGrainedDetail, - PartiPromptsFoodBeverage, - PartiPromptsImagination, - PartiPromptsIllustrations, - PartiPromptsIndoorScenes, - PartiPromptsLinguisticStructures, - PartiPromptsOutdoorScenes, - PartiPromptsPeople, - PartiPromptsPerspective, - PartiPromptsProducePlants, - PartiPromptsPropertiesPositioning, - PartiPromptsQuantity, - PartiPromptsSimpleDetail, - PartiPromptsStyleFormat, - PartiPromptsVehicles, - PartiPromptsWorldKnowledge, - PartiPromptsWritingSymbols, -) - -__all__ = [ - "PartiPrompts", - "PartiPromptsAbstract", - "PartiPromptsAnimals", - "PartiPromptsArtifacts", - "PartiPromptsArts", - "PartiPromptsBasic", - "PartiPromptsComplex", - "PartiPromptsFineGrainedDetail", - "PartiPromptsFoodBeverage", - "PartiPromptsImagination", - "PartiPromptsIllustrations", - "PartiPromptsIndoorScenes", - "PartiPromptsLinguisticStructures", - "PartiPromptsOutdoorScenes", - "PartiPromptsPeople", - "PartiPromptsPerspective", - "PartiPromptsProducePlants", - "PartiPromptsPropertiesPositioning", - "PartiPromptsQuantity", - "PartiPromptsSimpleDetail", - "PartiPromptsStyleFormat", - "PartiPromptsVehicles", - "PartiPromptsWorldKnowledge", - "PartiPromptsWritingSymbols", -] diff --git a/src/pruna/evaluation/benchmarks/text_to_image/parti.py b/src/pruna/evaluation/benchmarks/text_to_image/parti.py deleted file mode 100644 index 089b64df..00000000 --- a/src/pruna/evaluation/benchmarks/text_to_image/parti.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Iterator, List, cast - -from datasets import Dataset, load_dataset - -from pruna.evaluation.benchmarks.base import TASK, Benchmark, BenchmarkEntry - - -class PartiPrompts(Benchmark): - """Parti Prompts benchmark for text-to-image generation.""" - - def __init__( - self, - seed: int = 42, - num_samples: int | None = None, - subset: str | None = None, - ): - """ - Initialize the Parti Prompts benchmark. - - Parameters - ---------- - seed : int - Random seed for shuffling. Default is 42. - num_samples : int | None - Number of samples to select. If None, uses all samples. Default is None. - subset : str | None - Filter by a subset of the dataset. For PartiPrompts, this can be either: - - **Categories:** - - "Abstract" - - "Animals" - - "Artifacts" - - "Arts" - - "Food & Beverage" - - "Illustrations" - - "Indoor Scenes" - - "Outdoor Scenes" - - "People" - - "Produce & Plants" - - "Vehicles" - - "World Knowledge" - - **Challenges:** - - "Basic" - - "Complex" - - "Fine-grained Detail" - - "Imagination" - - "Linguistic Structures" - - "Perspective" - - "Properties & Positioning" - - "Quantity" - - "Simple Detail" - - "Style & Format" - - "Writing & Symbols" - - If None, includes all samples. 
Default is None. - """ - super().__init__() - self._seed = seed - self._num_samples = num_samples - - # Determine if subset refers to a dataset category or challenge - # Check against known challenges - self.subset = subset - - def _load_prompts(self) -> List[dict]: - """Load prompts from the dataset.""" - dataset_dict = load_dataset("nateraw/parti-prompts") # type: ignore - dataset = cast(Dataset, dataset_dict["train"]) # type: ignore - if self.subset is not None: - dataset = dataset.filter(lambda x: x["Category"] == self.subset or x["Challenge"] == self.subset) - shuffled_dataset = dataset.shuffle(seed=self._seed) - if self._num_samples is not None: - selected_dataset = shuffled_dataset.select(range(min(self._num_samples, len(shuffled_dataset)))) - else: - selected_dataset = shuffled_dataset - return list(selected_dataset) - - def __iter__(self) -> Iterator[BenchmarkEntry]: - """Iterate over benchmark entries.""" - for i, row in enumerate(self._load_prompts()): - yield BenchmarkEntry( - model_inputs={"prompt": row["Prompt"]}, - model_outputs={}, - path=f"{i}.png", - additional_info={ - "category": row["Category"], - "challenge": row["Challenge"], - "note": row.get("Note", ""), - }, - task_type=self.task_type, - ) - - @property - def name(self) -> str: - """Return the unique name identifier.""" - if self.subset is None: - return "parti_prompts" - normalized = ( - self.subset.lower().replace(" & ", "_").replace(" ", "_").replace("&", "_").replace("__", "_").rstrip("_") - ) - return f"parti_prompts_{normalized}" - - @property - def display_name(self) -> str: - """Return the human-readable display name.""" - if self.subset is None: - return "Parti Prompts" - return f"Parti Prompts ({self.subset})" - - def __len__(self) -> int: - """Return the number of entries in the benchmark.""" - return len(self._load_prompts()) - - @property - def metrics(self) -> List[str]: - """Return the list of recommended metrics.""" - return ["arniqa", "clip", "clip_iqa", "sharpness"] - - @property - def task_type(self) -> TASK: - """Return the task type.""" - return "text_to_image" - - @property - def description(self) -> str: - """Return a description of the benchmark.""" - return ( - "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects " - "ranging from basic to complex, enabling comprehensive assessment of model capabilities " - "across different domains and difficulty levels." 
- ) - - -# Category-based subclasses -class PartiPromptsAbstract(PartiPrompts): - """Parti Prompts filtered by Abstract category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Abstract" if subset is None else subset) - - -class PartiPromptsAnimals(PartiPrompts): - """Parti Prompts filtered by Animals category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Animals" if subset is None else subset) - - -class PartiPromptsArtifacts(PartiPrompts): - """Parti Prompts filtered by Artifacts category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Artifacts" if subset is None else subset) - - -class PartiPromptsArts(PartiPrompts): - """Parti Prompts filtered by Arts category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Arts" if subset is None else subset) - - -class PartiPromptsFoodBeverage(PartiPrompts): - """Parti Prompts filtered by Food & Beverage category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Food & Beverage" if subset is None else subset) - - -class PartiPromptsIllustrations(PartiPrompts): - """Parti Prompts filtered by Illustrations category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Illustrations" if subset is None else subset) - - -class PartiPromptsIndoorScenes(PartiPrompts): - """Parti Prompts filtered by Indoor Scenes category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Indoor Scenes" if subset is None else subset) - - -class PartiPromptsOutdoorScenes(PartiPrompts): - """Parti Prompts filtered by Outdoor Scenes category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Outdoor Scenes" if subset is None else subset) - - -class PartiPromptsPeople(PartiPrompts): - """Parti Prompts filtered by People category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="People" if subset is None else subset) - - -class PartiPromptsProducePlants(PartiPrompts): - """Parti Prompts filtered by Produce & Plants category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Produce & Plants" if subset is None else subset) - - -class PartiPromptsVehicles(PartiPrompts): - """Parti Prompts filtered by Vehicles category.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Vehicles" if subset is None else subset) - - -class PartiPromptsWorldKnowledge(PartiPrompts): - """Parti Prompts filtered by World Knowledge category.""" - - def 
__init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="World Knowledge" if subset is None else subset) - - -# Challenge-based subclasses -class PartiPromptsBasic(PartiPrompts): - """Parti Prompts filtered by Basic challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - # subset can be a category to further filter when challenge is already set - super().__init__(seed=seed, num_samples=num_samples, subset="Basic" if subset is None else subset) - - -class PartiPromptsComplex(PartiPrompts): - """Parti Prompts filtered by Complex challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Complex" if subset is None else subset) - - -class PartiPromptsFineGrainedDetail(PartiPrompts): - """Parti Prompts filtered by Fine-grained Detail challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Fine-grained Detail" if subset is None else subset) - - -class PartiPromptsImagination(PartiPrompts): - """Parti Prompts filtered by Imagination challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Imagination" if subset is None else subset) - - -class PartiPromptsLinguisticStructures(PartiPrompts): - """Parti Prompts filtered by Linguistic Structures challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__( - seed=seed, num_samples=num_samples, subset="Linguistic Structures" if subset is None else subset - ) - - -class PartiPromptsPerspective(PartiPrompts): - """Parti Prompts filtered by Perspective challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Perspective" if subset is None else subset) - - -class PartiPromptsPropertiesPositioning(PartiPrompts): - """Parti Prompts filtered by Properties & Positioning challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__( - seed=seed, num_samples=num_samples, subset="Properties & Positioning" if subset is None else subset - ) - - -class PartiPromptsQuantity(PartiPrompts): - """Parti Prompts filtered by Quantity challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Quantity" if subset is None else subset) - - -class PartiPromptsSimpleDetail(PartiPrompts): - """Parti Prompts filtered by Simple Detail challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Simple Detail" if subset is None else subset) - - -class PartiPromptsStyleFormat(PartiPrompts): - """Parti Prompts filtered by Style & Format challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Style & Format" if subset is None else subset) - - -class 
PartiPromptsWritingSymbols(PartiPrompts): - """Parti Prompts filtered by Writing & Symbols challenge.""" - - def __init__(self, seed: int = 42, num_samples: int | None = None, subset: str | None = None): - super().__init__(seed=seed, num_samples=num_samples, subset="Writing & Symbols" if subset is None else subset) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 04d226f6..61550698 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,10 +1,10 @@ from typing import Any, Callable import pytest -from transformers import AutoTokenizer -from datasets import Dataset -from torch.utils.data import TensorDataset import torch +from transformers import AutoTokenizer + +from pruna.data import BenchmarkInfo, benchmark_info from pruna.data.datasets.image import setup_imagenet_dataset from pruna.data.pruna_datamodule import PrunaDataModule @@ -80,3 +80,19 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar assert labels.dtype == torch.int64 # iterate through the dataloaders iterate_dataloaders(datamodule) + + + +@pytest.mark.slow +def test_parti_prompts_with_category_filter(): + """Test PartiPrompts loading with category filter.""" + dm = PrunaDataModule.from_string( + "PartiPrompts", category="Animals", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(aux["Category"] == "Animals" for aux in auxiliaries) diff --git a/tests/evaluation/test_benchmarks.py b/tests/evaluation/test_benchmarks.py deleted file mode 100644 index d9345874..00000000 --- a/tests/evaluation/test_benchmarks.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for the benchmarks module.""" - -import pytest - -from pruna.evaluation.benchmarks.base import Benchmark, BenchmarkEntry, TASK -from pruna.evaluation.benchmarks.registry import BenchmarkRegistry -from pruna.evaluation.benchmarks.adapter import benchmark_to_datasets -from pruna.evaluation.benchmarks.text_to_image.parti import PartiPrompts - - -def test_benchmark_entry_creation(): - """Test creating a BenchmarkEntry with all fields.""" - entry = BenchmarkEntry( - model_inputs={"prompt": "test prompt"}, - model_outputs={"image": "test_image.png"}, - path="test.png", - additional_info={"category": "test"}, - task_type="text_to_image", - ) - - assert entry.model_inputs == {"prompt": "test prompt"} - assert entry.model_outputs == {"image": "test_image.png"} - assert entry.path == "test.png" - assert entry.additional_info == {"category": "test"} - assert entry.task_type == "text_to_image" - - -def test_benchmark_entry_defaults(): - """Test BenchmarkEntry with default values.""" - entry = BenchmarkEntry(model_inputs={"prompt": "test"}) - - assert entry.model_inputs == {"prompt": "test"} - assert entry.model_outputs == {} - assert entry.path == "" - assert entry.additional_info == {} - assert entry.task_type == "text_to_image" - - -def test_task_type_literal(): - """Test that TASK type only accepts valid task types.""" - # Valid task types - valid_tasks: list[TASK] = [ - "text_to_image", - "text_generation", - "audio", - "image_classification", - "question_answering", - ] - - for task in valid_tasks: - entry = BenchmarkEntry(model_inputs={}, task_type=task) - assert entry.task_type == task - - -def test_benchmark_registry_get(): - """Test getting a benchmark from the registry.""" - benchmark_class = BenchmarkRegistry.get("parti_prompts") - assert benchmark_class is not None - assert issubclass(benchmark_class, Benchmark) - - -def test_benchmark_registry_list_all(): - """Test listing all registered benchmarks.""" - all_benchmarks = BenchmarkRegistry.list_all() - assert isinstance(all_benchmarks, dict) - assert len(all_benchmarks) > 0 - assert "parti_prompts" in all_benchmarks - - -def test_benchmark_registry_get_nonexistent(): - """Test getting a non-existent benchmark returns None.""" - benchmark_class = BenchmarkRegistry.get("nonexistent_benchmark") - assert benchmark_class is None - - -def test_parti_prompts_creation(): - """Test creating a PartiPrompts benchmark instance.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - - assert benchmark.name == "parti_prompts" - assert benchmark.display_name == "Parti Prompts" - assert benchmark.task_type == "text_to_image" - assert len(benchmark.metrics) > 0 - assert isinstance(benchmark.description, str) - - -def test_parti_prompts_iteration(): - """Test iterating over PartiPrompts entries.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - entries = list(benchmark) - - assert len(entries) == 5 - for entry in entries: - assert isinstance(entry, BenchmarkEntry) - assert "prompt" in entry.model_inputs - assert entry.task_type == "text_to_image" - assert entry.model_outputs == {} - - -def test_parti_prompts_length(): - """Test PartiPrompts __len__ method.""" - benchmark = PartiPrompts(seed=42, num_samples=10) - assert len(benchmark) == 10 - - -def test_parti_prompts_subset(): - """Test PartiPrompts with a subset filter.""" - benchmark = PartiPrompts(seed=42, num_samples=5, subset="Animals") - - assert "animals" in benchmark.name.lower() - assert "Animals" in benchmark.display_name - - entries = list(benchmark) - for entry in entries: - 
assert entry.additional_info.get("category") == "Animals" - - -def test_benchmark_to_datasets(): - """Test converting a benchmark to datasets.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - train_ds, val_ds, test_ds = benchmark_to_datasets(benchmark) - - assert len(test_ds) == 5 - assert len(train_ds) == 1 # Dummy dataset - assert len(val_ds) == 1 # Dummy dataset - - # Check that test dataset has the expected fields - sample = test_ds[0] - assert "prompt" in sample or "text" in sample - - -def test_benchmark_entry_task_type_validation(): - """Test that BenchmarkEntry validates task_type.""" - # This should work - entry = BenchmarkEntry( - model_inputs={}, - task_type="text_to_image", - ) - assert entry.task_type == "text_to_image" - - # Test other valid task types - for task in ["text_generation", "audio", "image_classification", "question_answering"]: - entry = BenchmarkEntry(model_inputs={}, task_type=task) - assert entry.task_type == task diff --git a/tests/evaluation/test_benchmarks_integration.py b/tests/evaluation/test_benchmarks_integration.py deleted file mode 100644 index e7198f07..00000000 --- a/tests/evaluation/test_benchmarks_integration.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Integration tests for benchmarks with datamodule and metrics.""" - -import pytest - -from pruna.data.pruna_datamodule import PrunaDataModule -from pruna.evaluation.benchmarks.registry import BenchmarkRegistry -from pruna.evaluation.benchmarks.text_to_image.parti import PartiPrompts -from pruna.evaluation.metrics.registry import MetricRegistry - - -@pytest.mark.cpu -def test_datamodule_from_benchmark(): - """Test creating a PrunaDataModule from a benchmark.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - datamodule = PrunaDataModule.from_benchmark(benchmark) - - assert datamodule is not None - assert datamodule.test_dataset is not None - assert len(datamodule.test_dataset) == 5 - - -@pytest.mark.cpu -def test_datamodule_from_benchmark_string(): - """Test creating a PrunaDataModule from a benchmark name string.""" - datamodule = PrunaDataModule.from_string("parti_prompts", seed=42) - - assert datamodule is not None - # Limit to small number for testing - datamodule.limit_datasets(5) - - # Test that we can iterate through the dataloader - test_loader = datamodule.test_dataloader(batch_size=2) - batch = next(iter(test_loader)) - assert batch is not None - - -@pytest.mark.cpu -def test_benchmark_with_metrics(): - """Test that benchmarks provide recommended metrics.""" - benchmark = PartiPrompts(seed=42, num_samples=5) - recommended_metrics = benchmark.metrics - - assert isinstance(recommended_metrics, list) - assert len(recommended_metrics) > 0 - - # Check that metrics can be retrieved from registry - for metric_name in recommended_metrics: - # Some metrics might be registered, some might not - # Just verify the names are strings - assert isinstance(metric_name, str) - - -@pytest.mark.cpu -def test_benchmark_registry_integration(): - """Test that benchmarks are properly registered and can be used.""" - # Get benchmark from registry - benchmark_class = BenchmarkRegistry.get("parti_prompts") - assert benchmark_class is not None - - # Create instance - benchmark = benchmark_class(seed=42, num_samples=3) - - # Verify it works with datamodule - datamodule = PrunaDataModule.from_benchmark(benchmark) - assert datamodule is not None - - # Verify we can get entries - entries = list(benchmark) - assert len(entries) == 3 - - -@pytest.mark.cpu -def test_benchmark_task_type_mapping(): - """Test that benchmark task types map correctly to collate functions.""" - benchmark = PartiPrompts(seed=42, num_samples=3) - - # Create datamodule and verify it uses the correct collate function - datamodule = PrunaDataModule.from_benchmark(benchmark) - - # The collate function should be set based on task_type - assert datamodule.collate_fn is not None - - # Verify we can use the dataloader - test_loader = datamodule.test_dataloader(batch_size=1) - batch = next(iter(test_loader)) - assert batch is not None - - -@pytest.mark.cpu -def test_benchmark_entry_model_outputs(): - """Test that BenchmarkEntry can store model outputs.""" - from pruna.evaluation.benchmarks.base import BenchmarkEntry - - entry = BenchmarkEntry( - model_inputs={"prompt": "test"}, - model_outputs={"image": "generated_image.png", "score": 0.95}, - ) - - assert entry.model_outputs == {"image": "generated_image.png", "score": 0.95} - - # Verify entries from benchmark have empty model_outputs by default - benchmark = PartiPrompts(seed=42, num_samples=2) - entries = list(benchmark) - - for entry in entries: - assert entry.model_outputs == {} - # But model_outputs field exists and can be populated - entry.model_outputs["test"] = "value" - assert 
entry.model_outputs["test"] == "value" From 975adb3fdb3845eb363367902162a7b7269edd5a Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:18:22 +0100 Subject: [PATCH 3/9] fix: add Numpydoc parameter docs for BenchmarkInfo Document all dataclass fields per Numpydoc PR01 with summary on new line per GL01. Co-authored-by: Cursor --- src/pruna/data/__init__.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 820d1262..86e36cd4 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -107,7 +107,24 @@ @dataclass class BenchmarkInfo: - """Metadata for a benchmark dataset.""" + """ + Metadata for a benchmark dataset. + + Parameters + ---------- + name : str + Internal identifier for the benchmark. + display_name : str + Human-readable name for display purposes. + description : str + Description of what the benchmark evaluates. + metrics : list[str] + List of metric names used for evaluation. + task_type : str + Type of task the benchmark evaluates (e.g., 'text_to_image'). + subsets : list[str] + Optional list of benchmark subset names. + """ name: str display_name: str From 6b0f4f7182b5fd39de462bbfb9fd396ab84efb51 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:20:59 +0100 Subject: [PATCH 4/9] feat: add benchmark discovery functions and expand benchmark registry - Add list_benchmarks() to filter benchmarks by task type - Add get_benchmark_info() to retrieve benchmark metadata - Add COCO, ImageNet, WikiText to benchmark_info registry Co-authored-by: Cursor --- src/pruna/data/__init__.py | 66 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 86e36cd4..315db752 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -192,4 +192,70 @@ class BenchmarkInfo: metrics=["clip", "fvd"], task_type="text_to_video", ), + "COCO": BenchmarkInfo( + name="coco", + display_name="COCO", + description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.", + metrics=["fid", "clip", "clip_iqa"], + task_type="text_to_image", + ), + "ImageNet": BenchmarkInfo( + name="imagenet", + display_name="ImageNet", + description="Large-scale image classification benchmark with 1000 classes.", + metrics=["accuracy", "top5_accuracy"], + task_type="image_classification", + ), + "WikiText": BenchmarkInfo( + name="wikitext", + display_name="WikiText", + description="Language modeling benchmark based on Wikipedia articles.", + metrics=["perplexity"], + task_type="text_generation", + ), } + + +def list_benchmarks(task_type: str | None = None) -> list[str]: + """ + List available benchmark names. + + Parameters + ---------- + task_type : str | None + Filter by task type (e.g., 'text_to_image', 'text_to_video'). + If None, returns all benchmarks. + + Returns + ------- + list[str] + List of benchmark names. + """ + if task_type is None: + return list(benchmark_info.keys()) + return [name for name, info in benchmark_info.items() if info.task_type == task_type] + + +def get_benchmark_info(name: str) -> BenchmarkInfo: + """ + Get benchmark metadata by name. + + Parameters + ---------- + name : str + The benchmark name. + + Returns + ------- + BenchmarkInfo + The benchmark metadata. + + Raises + ------ + KeyError + If benchmark name is not found. 
+ """ + if name not in benchmark_info: + available = ", ".join(benchmark_info.keys()) + raise KeyError(f"Benchmark '{name}' not found. Available: {available}") + return benchmark_info[name] From 56f2167391d00b96ba08d0968e6c0be588904a83 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:25:10 +0100 Subject: [PATCH 5/9] fix: use correct metric names from MetricRegistry Update benchmark metrics to match registered names: - clip -> clip_score - clip_iqa -> clipiqa - Remove unimplemented top5_accuracy Co-authored-by: Cursor --- src/pruna/data/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 315db752..ce404911 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -143,7 +143,7 @@ class BenchmarkInfo: "ranging from basic to complex, enabling comprehensive assessment of model capabilities " "across different domains and difficulty levels." ), - metrics=["arniqa", "clip", "clip_iqa", "sharpness"], + metrics=["arniqa", "clip_score", "clipiqa", "sharpness"], task_type="text_to_image", subsets=[ "Abstract", @@ -175,35 +175,35 @@ class BenchmarkInfo: name="drawbench", display_name="DrawBench", description="A comprehensive benchmark for evaluating text-to-image generation models.", - metrics=["clip", "clip_iqa", "sharpness"], + metrics=["clip_score", "clipiqa", "sharpness"], task_type="text_to_image", ), "GenAIBench": BenchmarkInfo( name="genai_bench", display_name="GenAI Bench", description="A benchmark for evaluating generative AI models.", - metrics=["clip", "clip_iqa", "sharpness"], + metrics=["clip_score", "clipiqa", "sharpness"], task_type="text_to_image", ), "VBench": BenchmarkInfo( name="vbench", display_name="VBench", description="A benchmark for evaluating video generation models.", - metrics=["clip", "fvd"], + metrics=["clip_score"], task_type="text_to_video", ), "COCO": BenchmarkInfo( name="coco", display_name="COCO", description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.", - metrics=["fid", "clip", "clip_iqa"], + metrics=["fid", "clip_score", "clipiqa"], task_type="text_to_image", ), "ImageNet": BenchmarkInfo( name="imagenet", display_name="ImageNet", description="Large-scale image classification benchmark with 1000 classes.", - metrics=["accuracy", "top5_accuracy"], + metrics=["accuracy"], task_type="image_classification", ), "WikiText": BenchmarkInfo( From 8f36c26c00120239ab725b3d08e06a350d7045c9 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:34:01 +0100 Subject: [PATCH 6/9] feat: add OneIG Text Rendering benchmark - Add setup_oneig_text_rendering_dataset in datasets/prompt.py - Register OneIGTextRendering in base_datasets - Add BenchmarkInfo entry with clip_score, clipiqa metrics - Auxiliaries include text_content for OCR evaluation - Add test for loading and auxiliaries Co-authored-by: Cursor --- [conflicted 2].coverage | Bin 0 -> 69632 bytes [conflicted 3].coverage | Bin 0 -> 69632 bytes [conflicted 4].coverage | Bin 0 -> 69632 bytes [conflicted 5].coverage | Bin 0 -> 69632 bytes [conflicted].coverage | Bin 0 -> 69632 bytes src/pruna/data/__init__.py | 9 ++++++ src/pruna/data/datasets/prompt.py | 47 ++++++++++++++++++++++++++++++ tests/data/test_datamodule.py | 16 ++++++++++ 8 files changed, 72 insertions(+) create mode 100644 [conflicted 2].coverage create mode 100644 [conflicted 3].coverage create mode 100644 [conflicted 4].coverage create mode 100644 
[conflicted 5].coverage
 create mode 100644  [conflicted].coverage

[GIT binary patches for the five accidentally committed " [conflicted*].coverage" coverage-data files omitted]
diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py
index ce404911..67875a71 100644
--- a/src/pruna/data/__init__.py
+++ b/src/pruna/data/__init__.py
@@ -29,6 +29,7 @@
 from pruna.data.datasets.prompt import (
     setup_drawbench_dataset,
     setup_genai_bench_dataset,
+    setup_oneig_text_rendering_dataset,
     setup_parti_prompts_dataset,
 )
 from pruna.data.datasets.question_answering import setup_polyglot_dataset
@@ -100,6 +101,7 @@
     "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
     "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
     "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
+    "OneIGTextRendering": (setup_oneig_text_rendering_dataset, "prompt_with_auxiliaries_collate", {}),
     "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
     "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
 }
@@ -213,6 +215,13 @@ class BenchmarkInfo:
         metrics=["perplexity"],
         task_type="text_generation",
     ),
+    "OneIGTextRendering": BenchmarkInfo(
+        name="oneig_text_rendering",
+        display_name="OneIG Text Rendering",
+        description="Evaluates text rendering quality in generated images using OCR-based metrics.",
+        metrics=["clip_score", "clipiqa"],
+        task_type="text_to_image",
+    ),
 }
 
 
diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py
index 4f275675..7c602314 100644
--- a/src/pruna/data/datasets/prompt.py
+++ b/src/pruna/data/datasets/prompt.py
@@ -104,3 +104,50 @@ def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
     ds = ds.rename_column("Prompt", "text")
     pruna_logger.info("GenAI-Bench is a test-only dataset. Do not use it for training or validation.")
     return ds.select([0]), ds.select([0]), ds
+
+
+def setup_oneig_text_rendering_dataset(
+    seed: int,
+    num_samples: int | None = None,
+) -> Tuple[Dataset, Dataset, Dataset]:
+    """
+    Setup the OneIG Text Rendering benchmark dataset.
+
+    License: Apache 2.0
+
+    Parameters
+    ----------
+    seed : int
+        The seed to use.
+    num_samples : int | None
+        Maximum number of samples to return. If None, returns all samples.
+ + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The OneIG Text Rendering dataset (dummy train, dummy val, test). + """ + import csv + import io + + import requests + + url = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/main/benchmark/text_rendering.csv" + response = requests.get(url) + reader = csv.DictReader(io.StringIO(response.text)) + + records = [] + for row in reader: + records.append({ + "text": row.get("prompt", ""), + "text_content": row.get("text_content", row.get("text", "")), + }) + + ds = Dataset.from_list(records) + ds = ds.shuffle(seed=seed) + + if num_samples is not None: + ds = ds.select(range(min(num_samples, len(ds)))) + + pruna_logger.info("OneIG Text Rendering is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 61550698..40ba2cb0 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -45,6 +45,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("GenAIBench", dict(), marks=pytest.mark.slow), pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("VBench", dict(), marks=pytest.mark.slow), + pytest.param("OneIGTextRendering", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: @@ -96,3 +97,18 @@ def test_parti_prompts_with_category_filter(): assert len(prompts) == 4 assert all(isinstance(p, str) for p in prompts) assert all(aux["Category"] == "Animals" for aux in auxiliaries) + + +@pytest.mark.slow +def test_oneig_text_rendering_auxiliaries(): + """Test OneIGTextRendering loading with auxiliaries.""" + dm = PrunaDataModule.from_string( + "OneIGTextRendering", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all("text_content" in aux for aux in auxiliaries) From 343eeb1c25c983abcd492a0703848ac85034298c Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:37:35 +0100 Subject: [PATCH 7/9] feat: add OneIG Alignment benchmark - Add setup_oneig_alignment_dataset in datasets/prompt.py - Support category filter (Anime_Stylization, Portrait, General_Object) - Register OneIGAlignment in base_datasets - Add BenchmarkInfo entry with accuracy metric, task_type text_generation - Auxiliaries include questions, dependencies, category - Add test for loading with category filter Co-authored-by: Cursor --- src/pruna/data/__init__.py | 14 ++++++- src/pruna/data/datasets/prompt.py | 68 +++++++++++++++++++++++++++++++ tests/data/test_datamodule.py | 17 ++++++++ 3 files changed, 97 insertions(+), 2 deletions(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 67875a71..e48d2608 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -29,6 +29,7 @@ from pruna.data.datasets.prompt import ( setup_drawbench_dataset, setup_genai_bench_dataset, + setup_oneig_alignment_dataset, setup_oneig_text_rendering_dataset, setup_parti_prompts_dataset, ) @@ -102,6 +103,7 @@ "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}), "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), "OneIGTextRendering": (setup_oneig_text_rendering_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGAlignment": 
(setup_oneig_alignment_dataset, "prompt_with_auxiliaries_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), } @@ -219,8 +221,16 @@ class BenchmarkInfo: name="oneig_text_rendering", display_name="OneIG Text Rendering", description="Evaluates text rendering quality in generated images using OCR-based metrics.", - metrics=["clip_score", "clipiqa"], - task_type="text_to_image", + metrics=["accuracy"], + task_type="text_generation", + ), + "OneIGAlignment": BenchmarkInfo( + name="oneig_alignment", + display_name="OneIG Alignment", + description="Evaluates image-text alignment for anime, human, and object generation with VQA-based questions.", + metrics=["accuracy"], + task_type="text_generation", + subsets=["Anime_Stylization", "Portrait", "General_Object"], ), } diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 7c602314..6cceb0d7 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -151,3 +151,71 @@ def setup_oneig_text_rendering_dataset( pruna_logger.info("OneIG Text Rendering is a test-only dataset. Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds + + +ONEIG_ALIGNMENT_CATEGORIES = ["Anime_Stylization", "Portrait", "General_Object"] + + +def setup_oneig_alignment_dataset( + seed: int, + category: str | None = None, + num_samples: int | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the OneIG Alignment benchmark dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + category : str | None + Filter by category. Available: Anime_Stylization, Portrait, General_Object. + num_samples : int | None + Maximum number of samples to return. If None, returns all samples. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The OneIG Alignment dataset (dummy train, dummy val, test). + """ + import json + + import requests + + ds = load_dataset("OneIG-Bench/OneIG-Bench")["test"] # type: ignore[index] + + url = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/main/benchmark/alignment_questions.json" + response = requests.get(url) + questions_data = json.loads(response.text) + + questions_by_id = {q["id"]: q for q in questions_data} + + records = [] + for row in ds: + row_id = row.get("id", "") + row_category = row.get("category", "") + + if category is not None: + if category not in ONEIG_ALIGNMENT_CATEGORIES: + raise ValueError(f"Invalid category: {category}. Must be one of {ONEIG_ALIGNMENT_CATEGORIES}") + if row_category != category: + continue + + q_info = questions_by_id.get(row_id, {}) + records.append({ + "text": row.get("prompt", ""), + "category": row_category, + "questions": q_info.get("questions", []), + "dependencies": q_info.get("dependencies", []), + }) + + ds = Dataset.from_list(records) + ds = ds.shuffle(seed=seed) + + if num_samples is not None: + ds = ds.select(range(min(num_samples, len(ds)))) + + pruna_logger.info("OneIG Alignment is a test-only dataset. 
Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 40ba2cb0..72e79681 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -46,6 +46,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("VBench", dict(), marks=pytest.mark.slow), pytest.param("OneIGTextRendering", dict(), marks=pytest.mark.slow), + pytest.param("OneIGAlignment", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: @@ -112,3 +113,19 @@ def test_oneig_text_rendering_auxiliaries(): assert len(prompts) == 4 assert all(isinstance(p, str) for p in prompts) assert all("text_content" in aux for aux in auxiliaries) + + +@pytest.mark.slow +def test_oneig_alignment_with_category_filter(): + """Test OneIGAlignment loading with category filter.""" + dm = PrunaDataModule.from_string( + "OneIGAlignment", category="Portrait", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(aux["category"] == "Portrait" for aux in auxiliaries) + assert all("questions" in aux for aux in auxiliaries) From 655039dc158114791bd8e17d859b112664b16f13 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:40:00 +0100 Subject: [PATCH 8/9] feat: add DPG (Descriptive Prompt Generation) benchmark - Add setup_dpg_dataset in datasets/prompt.py - Support category filter (entity, attribute, relation, global, other) - Register DPG in base_datasets - Add BenchmarkInfo entry with accuracy metric, task_type text_generation - Auxiliaries include questions, category_broad - Add test for loading with category filter Co-authored-by: Cursor --- src/pruna/data/__init__.py | 10 +++++ src/pruna/data/datasets/prompt.py | 62 +++++++++++++++++++++++++++++++ tests/data/test_datamodule.py | 17 +++++++++ 3 files changed, 89 insertions(+) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index e48d2608..4ed3decf 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -27,6 +27,7 @@ setup_mnist_dataset, ) from pruna.data.datasets.prompt import ( + setup_dpg_dataset, setup_drawbench_dataset, setup_genai_bench_dataset, setup_oneig_alignment_dataset, @@ -104,6 +105,7 @@ "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), "OneIGTextRendering": (setup_oneig_text_rendering_dataset, "prompt_with_auxiliaries_collate", {}), "OneIGAlignment": (setup_oneig_alignment_dataset, "prompt_with_auxiliaries_collate", {}), + "DPG": (setup_dpg_dataset, "prompt_with_auxiliaries_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), } @@ -232,6 +234,14 @@ class BenchmarkInfo: task_type="text_generation", subsets=["Anime_Stylization", "Portrait", "General_Object"], ), + "DPG": BenchmarkInfo( + name="dpg", + display_name="DPG", + description="Descriptive Prompt Generation benchmark for evaluating image understanding across entity, attribute, relation, and global aspects.", + metrics=["accuracy"], + task_type="text_generation", + subsets=["entity", "attribute", "relation", "global", "other"], + ), } diff --git a/src/pruna/data/datasets/prompt.py 
b/src/pruna/data/datasets/prompt.py index 6cceb0d7..c5932449 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -219,3 +219,65 @@ def setup_oneig_alignment_dataset( pruna_logger.info("OneIG Alignment is a test-only dataset. Do not use it for training or validation.") return ds.select([0]), ds.select([0]), ds + + +DPG_CATEGORIES = ["entity", "attribute", "relation", "global", "other"] + + +def setup_dpg_dataset( + seed: int, + category: str | None = None, + num_samples: int | None = None, +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the DPG (Descriptive Prompt Generation) benchmark dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + category : str | None + Filter by category. Available: entity, attribute, relation, global, other. + num_samples : int | None + Maximum number of samples to return. If None, returns all samples. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The DPG dataset (dummy train, dummy val, test). + """ + import csv + import io + + import requests + + url = "https://raw.githubusercontent.com/TencentQQGYLab/ELLA/main/dpg_bench/prompts.csv" + response = requests.get(url) + reader = csv.DictReader(io.StringIO(response.text)) + + records = [] + for row in reader: + row_category = row.get("category", row.get("category_broad", "")) + + if category is not None: + if category not in DPG_CATEGORIES: + raise ValueError(f"Invalid category: {category}. Must be one of {DPG_CATEGORIES}") + if row_category != category: + continue + + records.append({ + "text": row.get("prompt", ""), + "category_broad": row_category, + "questions": row.get("questions", "").split("|") if row.get("questions") else [], + }) + + ds = Dataset.from_list(records) + ds = ds.shuffle(seed=seed) + + if num_samples is not None: + ds = ds.select(range(min(num_samples, len(ds)))) + + pruna_logger.info("DPG is a test-only dataset. 
Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 72e79681..0255c07e 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -47,6 +47,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("VBench", dict(), marks=pytest.mark.slow), pytest.param("OneIGTextRendering", dict(), marks=pytest.mark.slow), pytest.param("OneIGAlignment", dict(), marks=pytest.mark.slow), + pytest.param("DPG", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: @@ -129,3 +130,19 @@ def test_oneig_alignment_with_category_filter(): assert all(isinstance(p, str) for p in prompts) assert all(aux["category"] == "Portrait" for aux in auxiliaries) assert all("questions" in aux for aux in auxiliaries) + + +@pytest.mark.slow +def test_dpg_with_category_filter(): + """Test DPG loading with category filter.""" + dm = PrunaDataModule.from_string( + "DPG", category="entity", dataloader_args={"batch_size": 4} + ) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(aux["category_broad"] == "entity" for aux in auxiliaries) + assert all("questions" in aux for aux in auxiliaries) From 03c7a83d48e15b03b81e61060422b5578008476e Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 31 Jan 2026 16:59:00 +0100 Subject: [PATCH 9/9] fix: address review feedback for DPG benchmark - Fix task_type from text_generation to text_to_image for DPG, OneIGAlignment, and OneIGTextRendering - Remove unused imports in test file Co-authored-by: Cursor --- src/pruna/data/__init__.py | 11 +++++++---- tests/data/test_datamodule.py | 1 - 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 4ed3decf..420be42e 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -224,22 +224,25 @@ class BenchmarkInfo: display_name="OneIG Text Rendering", description="Evaluates text rendering quality in generated images using OCR-based metrics.", metrics=["accuracy"], - task_type="text_generation", + task_type="text_to_image", ), "OneIGAlignment": BenchmarkInfo( name="oneig_alignment", display_name="OneIG Alignment", description="Evaluates image-text alignment for anime, human, and object generation with VQA-based questions.", metrics=["accuracy"], - task_type="text_generation", + task_type="text_to_image", subsets=["Anime_Stylization", "Portrait", "General_Object"], ), "DPG": BenchmarkInfo( name="dpg", display_name="DPG", - description="Descriptive Prompt Generation benchmark for evaluating image understanding across entity, attribute, relation, and global aspects.", + description=( + "Descriptive Prompt Generation benchmark for evaluating image understanding " + "across entity, attribute, relation, and global aspects." 
+ ), metrics=["accuracy"], - task_type="text_generation", + task_type="text_to_image", subsets=["entity", "attribute", "relation", "global", "other"], ), } diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 0255c07e..172dc95b 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -4,7 +4,6 @@ import torch from transformers import AutoTokenizer -from pruna.data import BenchmarkInfo, benchmark_info from pruna.data.datasets.image import setup_imagenet_dataset from pruna.data.pruna_datamodule import PrunaDataModule
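
A minimal usage sketch for the datasets added in this series, mirroring the new slow tests. It assumes pruna with these patches applied, network access to the Hugging Face Hub and to the raw GitHub CSV/JSON files the setup functions fetch, and that `PrunaDataModule.from_string` forwards extra keyword arguments such as `category` to the dataset setup function, as the tests above rely on.

from pruna.data.pruna_datamodule import PrunaDataModule

# OneIG text rendering: prompts plus the expected rendered text in the auxiliaries.
dm = PrunaDataModule.from_string("OneIGTextRendering", dataloader_args={"batch_size": 4})
dm.limit_datasets(10)  # keep the smoke run small
prompts, auxiliaries = next(iter(dm.test_dataloader()))
print(prompts[0], auxiliaries[0]["text_content"])

# OneIG alignment, filtered to a single category; auxiliaries carry the VQA-style questions.
dm = PrunaDataModule.from_string(
    "OneIGAlignment", category="Portrait", dataloader_args={"batch_size": 4}
)
dm.limit_datasets(10)
prompts, auxiliaries = next(iter(dm.test_dataloader()))
print(auxiliaries[0]["category"], auxiliaries[0]["questions"])

# DPG with a category filter; the broad category and questions ride along as auxiliaries.
dm = PrunaDataModule.from_string("DPG", category="entity", dataloader_args={"batch_size": 4})
dm.limit_datasets(10)
prompts, auxiliaries = next(iter(dm.test_dataloader()))
print(auxiliaries[0]["category_broad"], auxiliaries[0]["questions"])

All three setups are test-only: the train and validation splits they return are single-row placeholders, so only the test dataloader is meaningful here.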