feat: add benchmark support to PrunaDataModule and implement PartiPrompts #502
davidberenstein1957 wants to merge 5 commits into main
Conversation
…mpts benchmark

- Introduced `from_benchmark` method in `PrunaDataModule` to create instances from benchmark classes.
- Added `Benchmark`, `BenchmarkEntry`, and `BenchmarkRegistry` classes for managing benchmarks.
- Implemented `PartiPrompts` benchmark for text-to-image generation with various categories and challenges.
- Created utility function `benchmark_to_datasets` to convert benchmarks into datasets compatible with `PrunaDataModule`.
- Added integration tests for benchmark functionality and data module interactions.
…filtering

- Remove heavy benchmark abstraction (Benchmark class, registry, adapter, 24 subclasses)
- Extend setup_parti_prompts_dataset with category and num_samples params
- Add BenchmarkInfo dataclass for metadata (metrics, description, subsets)
- Switch PartiPrompts to prompt_with_auxiliaries_collate to preserve Category/Challenge
- Merge tests into test_datamodule.py

Reduces 964 lines to 128 lines (87% reduction)

Co-authored-by: Cursor <cursoragent@cursor.com>
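For orientation, here is a minimal sketch of what the `BenchmarkInfo` dataclass described in this commit might look like. The field names follow the commit message (description, metrics, subsets); the actual definition and defaults in the PR may differ.

```python
from dataclasses import dataclass, field


@dataclass
class BenchmarkInfo:
    """Benchmark metadata (sketch only; the PR's real fields may differ).

    description: human-readable summary of the benchmark.
    metrics: suggested metric names, e.g. ["arniqa", "clip_score", "clipiqa", "sharpness"].
    subsets: filterable subsets, e.g. PartiPrompts Category/Challenge values.
    """

    description: str
    metrics: list[str] = field(default_factory=list)
    subsets: list[str] = field(default_factory=list)
```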
Document all dataclass fields per Numpydoc PR01 with summary on new line per GL01.

Co-authored-by: Cursor <cursoragent@cursor.com>
- Add list_benchmarks() to filter benchmarks by task type
- Add get_benchmark_info() to retrieve benchmark metadata
- Add COCO, ImageNet, WikiText to benchmark_info registry

Co-authored-by: Cursor <cursoragent@cursor.com>
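A rough sketch of how `list_benchmarks()` could filter the registry by task type. It assumes `benchmark_info` maps names to `BenchmarkInfo` objects and that each entry carries some task label; that attribute name is not confirmed by the diff.

```python
from pruna.data import benchmark_info  # registry added in this PR: dict of name -> BenchmarkInfo


def list_benchmarks(task: str | None = None) -> list[str]:
    """Return registered benchmark names, optionally filtered by task type (sketch only)."""
    if task is None:
        return list(benchmark_info.keys())
    # `task` attribute is an assumption; the PR may store this metadata differently.
    return [name for name, info in benchmark_info.items() if getattr(info, "task", None) == task]
```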
Update benchmark metrics to match registered names:

- clip -> clip_score
- clip_iqa -> clipiqa
- Remove unimplemented top5_accuracy

Co-authored-by: Cursor <cursoragent@cursor.com>
Cursor Bugbot has reviewed your changes and found 3 potential issues.
```python
ds = load_dataset("nateraw/parti-prompts")["train"]  # type: ignore[index]

if category is not None:
    ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)
```
Category filter fails silently when list is passed
Medium Severity
The `setup_parti_prompts_dataset` function's `category` parameter only accepts `str | None`, but `PrunaDataModule.from_string` accepts `category: str | list[str] | None`. When a list is passed, the filter `x["Category"] == category or x["Challenge"] == category` compares strings against a list, which always evaluates to False. This silently filters out all records, resulting in an empty dataset that causes `ds.select([0])` to fail with an index error.
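One way to address this, sketched below rather than taken from the PR, is to normalize `category` to a list before filtering so that a single string and a list of strings behave the same:

```python
# Possible fix (illustrative, not the PR's patch): normalize `category` to a list
# so both a single category and a list of categories filter correctly.
if category is not None:
    categories = [category] if isinstance(category, str) else list(category)
    ds = ds.filter(lambda x: x["Category"] in categories or x["Challenge"] in categories)
```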
```python
import torch
from transformers import AutoTokenizer

from pruna.data import BenchmarkInfo, benchmark_info
```
```python
if name not in benchmark_info:
    available = ", ".join(benchmark_info.keys())
    raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
return benchmark_info[name]
```
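A hypothetical usage example for `get_benchmark_info`; the import path and the exact registry key for PartiPrompts are assumptions based on the commit messages.

```python
from pruna.data import get_benchmark_info  # import path assumed; the PR may expose it elsewhere

# Registry keys are assumed to match the benchmark names mentioned in the commits
# ("PartiPrompts", "COCO", "ImageNet", "WikiText"); exact spelling may differ.
info = get_benchmark_info("PartiPrompts")
print(info.description)
print(info.metrics)  # e.g. ["arniqa", "clip_score", "clipiqa", "sharpness"]
```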
begumcig left a comment:
Really nice David! I think we shouldn't tightly couple metrics with datasets under benchmark, but otherwise it all looks good to me! Thank you!
```python
if category is not None:
    ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)

ds = ds.shuffle(seed=seed)
```
Since we are only creating a test set, how do you feel about not shuffling the data?
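If shuffling were dropped as suggested, a deterministic head of the split could be taken instead; this is only a sketch of that alternative, reusing the `num_samples` parameter added in this PR:

```python
# Sketch: deterministic subset without shuffling (alternative to ds.shuffle(seed=seed)),
# since only a test split is being built.
if num_samples is not None:
    ds = ds.select(range(min(num_samples, len(ds))))
```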
| "ranging from basic to complex, enabling comprehensive assessment of model capabilities " | ||
| "across different domains and difficulty levels." | ||
| ), | ||
| metrics=["arniqa", "clip_score", "clipiqa", "sharpness"], |
I don't think we should tightly couple benchmarking datasets with metrics. I think benchmarks should have their datasets available as PrunaDataModules, and the metrics for the Benchmarks should be Pruna Metrics. This way we can give the user the flexibility to use whichever dataset with whichever metric they choose, how do you feel?
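As a rough illustration of that decoupling (not code from this PR): the dataset is loaded as an ordinary `PrunaDataModule`, and the metrics listed in `BenchmarkInfo` are treated as a suggestion the user can override. Import paths, the dataset name string, the simplified `from_string` call, and the category value are all assumptions.

```python
from pruna.data import get_benchmark_info                # import path assumed
from pruna.data.pruna_datamodule import PrunaDataModule  # import path assumed

# Dataset: a plain PrunaDataModule, chosen independently of any metric.
# Dataset name, category value, and call signature are illustrative only.
datamodule = PrunaDataModule.from_string("PartiPrompts", category="Abstract")

# Metrics: BenchmarkInfo only *suggests* defaults; the user can pass any registered
# Pruna metric names to their evaluation setup instead.
suggested = get_benchmark_info("PartiPrompts").metrics   # e.g. ["arniqa", "clip_score", "clipiqa", "sharpness"]
chosen = ["clip_score"]                                  # user's own choice, decoupled from the dataset
```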


- Introduced `from_benchmark` method in `PrunaDataModule` to create instances from benchmark classes.
- Added `Benchmark`, `BenchmarkEntry`, and `BenchmarkRegistry` classes for managing benchmarks.
- Implemented `PartiPrompts` benchmark for text-to-image generation with various categories and challenges.
- Created utility function `benchmark_to_datasets` to convert benchmarks into datasets compatible with `PrunaDataModule`.

Description
Related Issue
Fixes #(issue number)
Type of Change
How Has This Been Tested?
Checklist
Additional Notes