diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py
index 3a811868..420be42e 100644
--- a/src/pruna/data/__init__.py
+++ b/src/pruna/data/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import dataclass, field
 from functools import partial
 from typing import Any, Callable, Tuple
 
@@ -26,8 +27,11 @@
     setup_mnist_dataset,
 )
 from pruna.data.datasets.prompt import (
+    setup_dpg_dataset,
     setup_drawbench_dataset,
     setup_genai_bench_dataset,
+    setup_oneig_alignment_dataset,
+    setup_oneig_text_rendering_dataset,
     setup_parti_prompts_dataset,
 )
 from pruna.data.datasets.question_answering import setup_polyglot_dataset
@@ -97,8 +101,193 @@
         {"img_size": 224},
     ),
     "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
-    "PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
+    "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
     "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
+    "OneIGTextRendering": (setup_oneig_text_rendering_dataset, "prompt_with_auxiliaries_collate", {}),
+    "OneIGAlignment": (setup_oneig_alignment_dataset, "prompt_with_auxiliaries_collate", {}),
+    "DPG": (setup_dpg_dataset, "prompt_with_auxiliaries_collate", {}),
     "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
     "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
 }
+
+
+@dataclass
+class BenchmarkInfo:
+    """
+    Metadata for a benchmark dataset.
+
+    Parameters
+    ----------
+    name : str
+        Internal identifier for the benchmark.
+    display_name : str
+        Human-readable name for display purposes.
+    description : str
+        Description of what the benchmark evaluates.
+    metrics : list[str]
+        List of metric names used for evaluation.
+    task_type : str
+        Type of task the benchmark evaluates (e.g., 'text_to_image').
+    subsets : list[str]
+        Optional list of benchmark subset names.
+    """
+
+    name: str
+    display_name: str
+    description: str
+    metrics: list[str]
+    task_type: str
+    subsets: list[str] = field(default_factory=list)
+
+
+benchmark_info: dict[str, BenchmarkInfo] = {
+    "PartiPrompts": BenchmarkInfo(
+        name="parti_prompts",
+        display_name="Parti Prompts",
+        description=(
+            "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
+            "ranging from basic to complex, enabling comprehensive assessment of model capabilities "
+            "across different domains and difficulty levels."
+        ),
+        metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+        subsets=[
+            "Abstract",
+            "Animals",
+            "Artifacts",
+            "Arts",
+            "Food & Beverage",
+            "Illustrations",
+            "Indoor Scenes",
+            "Outdoor Scenes",
+            "People",
+            "Produce & Plants",
+            "Vehicles",
+            "World Knowledge",
+            "Basic",
+            "Complex",
+            "Fine-grained Detail",
+            "Imagination",
+            "Linguistic Structures",
+            "Perspective",
+            "Properties & Positioning",
+            "Quantity",
+            "Simple Detail",
+            "Style & Format",
+            "Writing & Symbols",
+        ],
+    ),
+    "DrawBench": BenchmarkInfo(
+        name="drawbench",
+        display_name="DrawBench",
+        description=(
+            "200 prompts across 11 categories, introduced with Imagen, covering colors, counting, "
+            "spatial relations, text rendering, rare words, and misspellings."
+        ),
+        metrics=["clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+    ),
+    "GenAIBench": BenchmarkInfo(
+        name="genai_bench",
+        display_name="GenAI Bench",
+        description=(
+            "1,600 compositional prompts covering basic and advanced skills such as attribute binding, "
+            "relations, counting, comparison, and negation for text-to-image generation."
+        ),
+        metrics=["clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+    ),
+    "VBench": BenchmarkInfo(
+        name="vbench",
+        display_name="VBench",
+        description=(
+            "Evaluates video generation models across 16 fine-grained dimensions such as subject "
+            "consistency, motion smoothness, and aesthetic quality."
+        ),
+        metrics=["clip_score"],
+        task_type="text_to_video",
+    ),
+    "COCO": BenchmarkInfo(
+        name="coco",
+        display_name="COCO",
+        description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
+        metrics=["fid", "clip_score", "clipiqa"],
+        task_type="text_to_image",
+    ),
+    "ImageNet": BenchmarkInfo(
+        name="imagenet",
+        display_name="ImageNet",
+        description="Large-scale image classification benchmark with 1000 classes.",
+        metrics=["accuracy"],
+        task_type="image_classification",
+    ),
+    "WikiText": BenchmarkInfo(
+        name="wikitext",
+        display_name="WikiText",
+        description="Language modeling benchmark based on Wikipedia articles.",
+        metrics=["perplexity"],
+        task_type="text_generation",
+    ),
+    "OneIGTextRendering": BenchmarkInfo(
+        name="oneig_text_rendering",
+        display_name="OneIG Text Rendering",
+        description="Evaluates text rendering quality in generated images using OCR-based metrics.",
+        metrics=["accuracy"],
+        task_type="text_to_image",
+    ),
+    "OneIGAlignment": BenchmarkInfo(
+        name="oneig_alignment",
+        display_name="OneIG Alignment",
+        description="Evaluates image-text alignment for anime, human, and object generation with VQA-based questions.",
+        metrics=["accuracy"],
+        task_type="text_to_image",
+        subsets=["Anime_Stylization", "Portrait", "General_Object"],
+    ),
+    "DPG": BenchmarkInfo(
+        name="dpg",
+        display_name="DPG",
+        description=(
+            "Dense Prompt Graph benchmark (DPG-Bench) of long, information-dense prompts for evaluating "
+            "prompt adherence across entity, attribute, relation, and global aspects."
+        ),
+        metrics=["accuracy"],
+        task_type="text_to_image",
+        subsets=["entity", "attribute", "relation", "global", "other"],
+    ),
+}
+
+
+def list_benchmarks(task_type: str | None = None) -> list[str]:
+    """
+    List available benchmark names.
+
+    Parameters
+    ----------
+    task_type : str | None
+        Filter by task type (e.g., 'text_to_image', 'text_to_video').
+        If None, returns all benchmarks.
+
+    Returns
+    -------
+    list[str]
+        List of benchmark names.
+    """
+    if task_type is None:
+        return list(benchmark_info.keys())
+    return [name for name, info in benchmark_info.items() if info.task_type == task_type]
+
+
+def get_benchmark_info(name: str) -> BenchmarkInfo:
+    """
+    Get benchmark metadata by name.
+
+    Parameters
+    ----------
+    name : str
+        The benchmark name.
+
+    Returns
+    -------
+    BenchmarkInfo
+        The benchmark metadata.
+
+    Raises
+    ------
+    KeyError
+        If benchmark name is not found.
+    """
+    if name not in benchmark_info:
+        available = ", ".join(benchmark_info.keys())
+        raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
+    return benchmark_info[name]
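Note (editorial, not part of the patch): a minimal sketch of how the registry added above could be queried, assuming pruna is installed with these changes and that list_benchmarks and get_benchmark_info are importable from pruna.data as defined in this diff.

# Illustrative usage of the benchmark registry introduced above (a sketch, not part of the patch).
from pruna.data import get_benchmark_info, list_benchmarks

# Enumerate benchmarks registered for text-to-image evaluation.
for name in list_benchmarks(task_type="text_to_image"):
    info = get_benchmark_info(name)
    print(f"{info.display_name}: {', '.join(info.metrics)}")

# Unknown names raise a KeyError that lists the available benchmarks.
try:
    get_benchmark_info("NotABenchmark")
except KeyError as err:
    print(err)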
diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py
index 1f6fab71..c5932449 100644
--- a/src/pruna/data/datasets/prompt.py
+++ b/src/pruna/data/datasets/prompt.py
@@ -41,7 +41,11 @@ def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
     return ds.select([0]), ds.select([0]), ds
 
 
-def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
+def setup_parti_prompts_dataset(
+    seed: int,
+    category: str | None = None,
+    num_samples: int | None = None,
+) -> Tuple[Dataset, Dataset, Dataset]:
     """
     Setup the Parti Prompts dataset.
 
@@ -51,13 +55,30 @@ def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
     ----------
     seed : int
         The seed to use.
+    category : str | None
+        Filter by Category or Challenge. Available categories: Abstract, Animals, Artifacts,
+        Arts, Food & Beverage, Illustrations, Indoor Scenes, Outdoor Scenes, People,
+        Produce & Plants, Vehicles, World Knowledge. Available challenges: Basic, Complex,
+        Fine-grained Detail, Imagination, Linguistic Structures, Perspective,
+        Properties & Positioning, Quantity, Simple Detail, Style & Format, Writing & Symbols.
+    num_samples : int | None
+        Maximum number of samples to return. If None, returns all samples.
 
     Returns
     -------
     Tuple[Dataset, Dataset, Dataset]
-        The Parti Prompts dataset.
+        The Parti Prompts dataset (dummy train, dummy val, test).
     """
     ds = load_dataset("nateraw/parti-prompts")["train"]  # type: ignore[index]
+
+    if category is not None:
+        ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)
+
+    ds = ds.shuffle(seed=seed)
+
+    if num_samples is not None:
+        ds = ds.select(range(min(num_samples, len(ds))))
+
     ds = ds.rename_column("Prompt", "text")
     pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.")
     return ds.select([0]), ds.select([0]), ds
@@ -83,3 +104,180 @@ def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
     ds = ds.rename_column("Prompt", "text")
     pruna_logger.info("GenAI-Bench is a test-only dataset. Do not use it for training or validation.")
     return ds.select([0]), ds.select([0]), ds
+
+
+def setup_oneig_text_rendering_dataset(
+    seed: int,
+    num_samples: int | None = None,
+) -> Tuple[Dataset, Dataset, Dataset]:
+    """
+    Setup the OneIG Text Rendering benchmark dataset.
+
+    License: Apache 2.0
+
+    Parameters
+    ----------
+    seed : int
+        The seed to use.
+    num_samples : int | None
+        Maximum number of samples to return. If None, returns all samples.
+
+    Returns
+    -------
+    Tuple[Dataset, Dataset, Dataset]
+        The OneIG Text Rendering dataset (dummy train, dummy val, test).
+    """
+    import csv
+    import io
+
+    import requests
+
+    url = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/main/benchmark/text_rendering.csv"
+    response = requests.get(url)
+    reader = csv.DictReader(io.StringIO(response.text))
+
+    records = []
+    for row in reader:
+        records.append({
+            "text": row.get("prompt", ""),
+            "text_content": row.get("text_content", row.get("text", "")),
+        })
+
+    ds = Dataset.from_list(records)
+    ds = ds.shuffle(seed=seed)
+
+    if num_samples is not None:
+        ds = ds.select(range(min(num_samples, len(ds))))
+
+    pruna_logger.info("OneIG Text Rendering is a test-only dataset. Do not use it for training or validation.")
+    return ds.select([0]), ds.select([0]), ds
+
+
+ONEIG_ALIGNMENT_CATEGORIES = ["Anime_Stylization", "Portrait", "General_Object"]
+
+
+def setup_oneig_alignment_dataset(
+    seed: int,
+    category: str | None = None,
+    num_samples: int | None = None,
+) -> Tuple[Dataset, Dataset, Dataset]:
+    """
+    Setup the OneIG Alignment benchmark dataset.
+
+    License: Apache 2.0
+
+    Parameters
+    ----------
+    seed : int
+        The seed to use.
+    category : str | None
+        Filter by category. Available: Anime_Stylization, Portrait, General_Object.
+    num_samples : int | None
+        Maximum number of samples to return. If None, returns all samples.
+
+    Returns
+    -------
+    Tuple[Dataset, Dataset, Dataset]
+        The OneIG Alignment dataset (dummy train, dummy val, test).
+    """
+    import json
+
+    import requests
+
+    ds = load_dataset("OneIG-Bench/OneIG-Bench")["test"]  # type: ignore[index]
+
+    url = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/main/benchmark/alignment_questions.json"
+    response = requests.get(url)
+    questions_data = json.loads(response.text)
+
+    questions_by_id = {q["id"]: q for q in questions_data}
+
+    records = []
+    for row in ds:
+        row_id = row.get("id", "")
+        row_category = row.get("category", "")
+
+        if category is not None:
+            if category not in ONEIG_ALIGNMENT_CATEGORIES:
+                raise ValueError(f"Invalid category: {category}. Must be one of {ONEIG_ALIGNMENT_CATEGORIES}")
+            if row_category != category:
+                continue
+
+        q_info = questions_by_id.get(row_id, {})
+        records.append({
+            "text": row.get("prompt", ""),
+            "category": row_category,
+            "questions": q_info.get("questions", []),
+            "dependencies": q_info.get("dependencies", []),
+        })
+
+    ds = Dataset.from_list(records)
+    ds = ds.shuffle(seed=seed)
+
+    if num_samples is not None:
+        ds = ds.select(range(min(num_samples, len(ds))))
+
+    pruna_logger.info("OneIG Alignment is a test-only dataset. Do not use it for training or validation.")
+    return ds.select([0]), ds.select([0]), ds
+
+
+DPG_CATEGORIES = ["entity", "attribute", "relation", "global", "other"]
+
+
+def setup_dpg_dataset(
+    seed: int,
+    category: str | None = None,
+    num_samples: int | None = None,
+) -> Tuple[Dataset, Dataset, Dataset]:
+    """
+    Setup the DPG (Dense Prompt Graph) benchmark dataset.
+
+    License: Apache 2.0
+
+    Parameters
+    ----------
+    seed : int
+        The seed to use.
+    category : str | None
+        Filter by category. Available: entity, attribute, relation, global, other.
+    num_samples : int | None
+        Maximum number of samples to return. If None, returns all samples.
+
+    Returns
+    -------
+    Tuple[Dataset, Dataset, Dataset]
+        The DPG dataset (dummy train, dummy val, test).
+    """
+    import csv
+    import io
+
+    import requests
+
+    url = "https://raw.githubusercontent.com/TencentQQGYLab/ELLA/main/dpg_bench/prompts.csv"
+    response = requests.get(url)
+    reader = csv.DictReader(io.StringIO(response.text))
+
+    records = []
+    for row in reader:
+        row_category = row.get("category", row.get("category_broad", ""))
+
+        if category is not None:
+            if category not in DPG_CATEGORIES:
+                raise ValueError(f"Invalid category: {category}. Must be one of {DPG_CATEGORIES}")
+            if row_category != category:
+                continue
+
+        records.append({
+            "text": row.get("prompt", ""),
+            "category_broad": row_category,
+            "questions": row.get("questions", "").split("|") if row.get("questions") else [],
+        })
+
+    ds = Dataset.from_list(records)
+    ds = ds.shuffle(seed=seed)
+
+    if num_samples is not None:
+        ds = ds.select(range(min(num_samples, len(ds))))
+
+    pruna_logger.info("DPG is a test-only dataset. Do not use it for training or validation.")
+    return ds.select([0]), ds.select([0]), ds
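Note (editorial, not part of the patch): a minimal sketch of calling one of the new setup functions above directly, assuming network access since the underlying dataset is downloaded on first use.

# Illustrative direct use of setup_parti_prompts_dataset (a sketch, not part of the patch).
from pruna.data.datasets.prompt import setup_parti_prompts_dataset

# Dummy train/val splits plus the filtered, shuffled, truncated test split.
_, _, test_ds = setup_parti_prompts_dataset(seed=42, category="Animals", num_samples=16)
print(len(test_ds), test_ds[0]["text"])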
diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py
index 04d226f6..172dc95b 100644
--- a/tests/data/test_datamodule.py
+++ b/tests/data/test_datamodule.py
@@ -1,10 +1,9 @@
 from typing import Any, Callable
 
 import pytest
-from transformers import AutoTokenizer
-from datasets import Dataset
-from torch.utils.data import TensorDataset
 import torch
+from transformers import AutoTokenizer
+
 from pruna.data.datasets.image import setup_imagenet_dataset
 from pruna.data.pruna_datamodule import PrunaDataModule
@@ -45,6 +44,9 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
         pytest.param("GenAIBench", dict(), marks=pytest.mark.slow),
         pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow),
         pytest.param("VBench", dict(), marks=pytest.mark.slow),
+        pytest.param("OneIGTextRendering", dict(), marks=pytest.mark.slow),
+        pytest.param("OneIGAlignment", dict(), marks=pytest.mark.slow),
+        pytest.param("DPG", dict(), marks=pytest.mark.slow),
     ],
 )
 def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None:
@@ -80,3 +82,66 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: Callable, collate_fn_ar
         assert labels.dtype == torch.int64
     # iterate through the dataloaders
     iterate_dataloaders(datamodule)
+
+
+
+@pytest.mark.slow
+def test_parti_prompts_with_category_filter():
+    """Test PartiPrompts loading with category filter."""
+    dm = PrunaDataModule.from_string(
+        "PartiPrompts", category="Animals", dataloader_args={"batch_size": 4}
+    )
+    dm.limit_datasets(10)
+    batch = next(iter(dm.test_dataloader()))
+    prompts, auxiliaries = batch
+
+    assert len(prompts) == 4
+    assert all(isinstance(p, str) for p in prompts)
+    assert all(aux["Category"] == "Animals" for aux in auxiliaries)
+
+
+@pytest.mark.slow
+def test_oneig_text_rendering_auxiliaries():
+    """Test OneIGTextRendering loading with auxiliaries."""
+    dm = PrunaDataModule.from_string(
+        "OneIGTextRendering", dataloader_args={"batch_size": 4}
+    )
+    dm.limit_datasets(10)
+    batch = next(iter(dm.test_dataloader()))
+    prompts, auxiliaries = batch
+
+    assert len(prompts) == 4
+    assert all(isinstance(p, str) for p in prompts)
+    assert all("text_content" in aux for aux in auxiliaries)
+
+
+@pytest.mark.slow
+def test_oneig_alignment_with_category_filter():
+    """Test OneIGAlignment loading with category filter."""
+    dm = PrunaDataModule.from_string(
+        "OneIGAlignment", category="Portrait", dataloader_args={"batch_size": 4}
+    )
+    dm.limit_datasets(10)
+    batch = next(iter(dm.test_dataloader()))
+    prompts, auxiliaries = batch
+
+    assert len(prompts) == 4
+    assert all(isinstance(p, str) for p in prompts)
+    assert all(aux["category"] == "Portrait" for aux in auxiliaries)
+    assert all("questions" in aux for aux in auxiliaries)
+
+
+@pytest.mark.slow
+def test_dpg_with_category_filter():
+    """Test DPG loading with category filter."""
+    dm = PrunaDataModule.from_string(
+        "DPG", category="entity", dataloader_args={"batch_size": 4}
+    )
+    dm.limit_datasets(10)
+    batch = next(iter(dm.test_dataloader()))
+    prompts, auxiliaries = batch
+
+    assert len(prompts) == 4
+    assert all(isinstance(p, str) for p in prompts)
+    assert all(aux["category_broad"] == "entity" for aux in auxiliaries)
+    assert all("questions" in aux for aux in auxiliaries)
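Note (editorial, not part of the patch): an end-to-end sketch mirroring the new tests above, assuming the benchmark names registered in this diff and a working network connection for the underlying downloads.

# Illustrative end-to-end usage of a new benchmark via PrunaDataModule (a sketch, not part of the patch).
from pruna.data.pruna_datamodule import PrunaDataModule

dm = PrunaDataModule.from_string("DPG", category="entity", dataloader_args={"batch_size": 4})
dm.limit_datasets(10)

# prompt_with_auxiliaries_collate yields the prompts plus per-sample metadata
# (for DPG: "category_broad" and the VQA-style "questions" list).
prompts, auxiliaries = next(iter(dm.test_dataloader()))
print(prompts[0])
print(auxiliaries[0]["questions"][:2])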