Skip to content

feat: add benchmark support to PrunaDataModule and implement PartiPrompts#502

Open
davidberenstein1957 wants to merge 5 commits intomainfrom
feat/add-partiprompts-benchmark-to-pruna
Open

feat: add benchmark support to PrunaDataModule and implement PartiPrompts#502
davidberenstein1957 wants to merge 5 commits intomainfrom
feat/add-partiprompts-benchmark-to-pruna

Conversation

@davidberenstein1957
Copy link
Member

  • Introduced from_benchmark method in PrunaDataModule to create instances from benchmark classes.
  • Added Benchmark, BenchmarkEntry, and BenchmarkRegistry classes for managing benchmarks.
  • Implemented PartiPrompts benchmark for text-to-image generation with various categories and challenges.
  • Created utility function benchmark_to_datasets to convert benchmarks into datasets compatible with PrunaDataModule.
  • Added integration tests for benchmark functionality and data module interactions.

Description

Related Issue

Fixes #(issue number)

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

…mpts benchmark

- Introduced `from_benchmark` method in `PrunaDataModule` to create instances from benchmark classes.
- Added `Benchmark`, `BenchmarkEntry`, and `BenchmarkRegistry` classes for managing benchmarks.
- Implemented `PartiPrompts` benchmark for text-to-image generation with various categories and challenges.
- Created utility function `benchmark_to_datasets` to convert benchmarks into datasets compatible with `PrunaDataModule`.
- Added integration tests for benchmark functionality and data module interactions.
davidberenstein1957 and others added 4 commits January 31, 2026 15:50
…filtering

- Remove heavy benchmark abstraction (Benchmark class, registry, adapter, 24 subclasses)
- Extend setup_parti_prompts_dataset with category and num_samples params
- Add BenchmarkInfo dataclass for metadata (metrics, description, subsets)
- Switch PartiPrompts to prompt_with_auxiliaries_collate to preserve Category/Challenge
- Merge tests into test_datamodule.py

Reduces 964 lines to 128 lines (87% reduction)

Co-authored-by: Cursor <cursoragent@cursor.com>
Document all dataclass fields per Numpydoc PR01 with summary on new line per GL01.

Co-authored-by: Cursor <cursoragent@cursor.com>
- Add list_benchmarks() to filter benchmarks by task type
- Add get_benchmark_info() to retrieve benchmark metadata
- Add COCO, ImageNet, WikiText to benchmark_info registry

Co-authored-by: Cursor <cursoragent@cursor.com>
Update benchmark metrics to match registered names:
- clip -> clip_score
- clip_iqa -> clipiqa
- Remove unimplemented top5_accuracy

Co-authored-by: Cursor <cursoragent@cursor.com>
@davidberenstein1957 davidberenstein1957 marked this pull request as ready for review February 2, 2026 04:24
Copy link

@cursor cursor bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 3 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index]

if category is not None:
ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Category filter fails silently when list is passed

Medium Severity

The setup_parti_prompts_dataset function's category parameter only accepts str | None, but PrunaDataModule.from_string accepts category: str | list[str] | None. When a list is passed, the filter x["Category"] == category or x["Challenge"] == category compares strings against a list, which always evaluates to False. This silently filters out all records, resulting in an empty dataset that causes ds.select([0]) to fail with an index error.

Fix in Cursor Fix in Web

import torch
from transformers import AutoTokenizer

from pruna.data import BenchmarkInfo, benchmark_info
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused imports in test file

Low Severity

BenchmarkInfo and benchmark_info are imported but never used in the test file. These imports should be removed.

Fix in Cursor Fix in Web

if name not in benchmark_info:
available = ", ".join(benchmark_info.keys())
raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
return benchmark_info[name]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused functions defined but never called

Medium Severity

list_benchmarks() and get_benchmark_info() are defined but never called anywhere in the codebase. The PR description mentions a from_benchmark method that would presumably use these, but it's not implemented in this PR.

Fix in Cursor Fix in Web

Copy link
Member

@begumcig begumcig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice David! I think we shouldn't tightly couple metrics with datasets under benchmark, but otherwise it all looks good to me! Thank you!

if category is not None:
ds = ds.filter(lambda x: x["Category"] == category or x["Challenge"] == category)

ds = ds.shuffle(seed=seed)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we are only creating test set how do you feel about not shuffling the data?

"ranging from basic to complex, enabling comprehensive assessment of model capabilities "
"across different domains and difficulty levels."
),
metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should tightly couple benchmarking datasets with metrics. I think benchmarks should have their datasets available as PrunaDataModules, and the metrics for the Benchmarks should be Pruna Metrics. This way we can give the user the flexibility to use whichever dataset with whichever metric they choose, how do you feel?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants