223 changes: 223 additions & 0 deletions src/pruna/algorithms/padding_pruning.py
@@ -0,0 +1,223 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import inspect
from collections.abc import Iterable
from typing import Any

from ConfigSpace import OrdinalHyperparameter

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.model_checks import is_diffusers_model
from pruna.engine.save import SAVE_FUNCTIONS


class PaddingPruner(PrunaAlgorithmBase):
"""
Implement Padding Pruning for Diffusers pipelines.

Padding Pruning removes unused padding tokens from the prompt embedding of diffusers pipelines.
"""

algorithm_name: str = "padding_pruning"
group_tags: list[AlgorithmTag] = [AlgorithmTag.PRUNER]
references: dict[str, str] = {}
tokenizer_required: bool = True
processor_required: bool = False
runs_on: list[str] = ["cpu", "cuda", "accelerate"]
dataset_required: bool = False
save_fn = SAVE_FUNCTIONS.reapply
compatible_before: Iterable[str | AlgorithmTag] = ["qkv_diffusers"]
compatible_after: Iterable[str | AlgorithmTag] = [
AlgorithmTag.CACHER,
"hyper",
"torch_compile",
"stable_fast",
"hqq_diffusers",
"diffusers_int8",
"torchao",
"flash_attn3",
"ring_attn",
]

def get_hyperparameters(self) -> list:
"""
Get the hyperparameters for the Padding Pruner.

Returns
-------
list
A list of hyperparameters.
"""
return [
OrdinalHyperparameter(
"min_sequence_length",
sequence=[32, 64, 128, 256],
default_value=64,
meta=dict(desc="Minimum sequence length used to embed a prompt."),
),
]

def model_check_fn(self, model: Any) -> bool:
"""
Check if the model is a diffusers pipeline with a max_sequence_length parameter.

Parameters
----------
model : Any
The model instance to check.

Returns
-------
bool
True if the model is a diffusers pipeline with a max_sequence_length parameter.
"""
if not is_diffusers_model(model):
return False
signature = inspect.signature(model.__call__)
return "max_sequence_length" in signature.parameters

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Apply Padding Pruning to the pipeline.

Parameters
----------
model : Any
The pipeline to apply padding pruning to.
smash_config : SmashConfigPrefixWrapper
Configuration settings for the pruning.

Returns
-------
Any
The pipeline with Padding Pruning enabled.
"""
min_sequence_length = smash_config["min_sequence_length"]
model.padding_pruning_helper = PaddingPruningHelper(model, min_sequence_length, smash_config.tokenizer)
model.padding_pruning_helper.enable()
return model

def import_algorithm_packages(self) -> dict[str, Any]:
"""
Import necessary algorithm packages.

Returns
-------
dict
An empty dictionary as no packages are imported in this implementation.
"""
return dict()
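For orientation, a usage sketch rather than a definitive recipe: the SmashConfig keys follow the algorithm-prefix convention visible in this diff, the tokenizer mirrors the one added in tests/fixtures.py, and the smash/SmashConfig entry points and pipeline id are assumptions about Pruna's public API.

    from diffusers import FluxPipeline
    from pruna import SmashConfig, smash  # assumed public entry points

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")  # example model
    config = SmashConfig()
    config["pruner"] = "padding_pruning"  # assumed group key for pruner algorithms
    config["padding_pruning_min_sequence_length"] = 64  # prefixed hyperparameter from above
    config.add_tokenizer("openai/clip-vit-base-patch32")  # as in tests/fixtures.py
    smashed = smash(model=pipe, smash_config=config)
    image = smashed("a red bicycle").images[0]  # wrapped call now rounds the sequence length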


class PaddingPruningHelper:
"""
Helper for Padding Pruning.

Parameters
----------
pipe : Any
The diffusers pipeline to wrap.
min_tokens : int
The minimum number of tokens to embed a prompt.
tokenizer : Any
The tokenizer of the pipeline.
"""

def __init__(self, pipe: Any, min_tokens: int, tokenizer: Any) -> None:
self.pipe = pipe
self.min_tokens = min_tokens
self.tokenizer = tokenizer

def enable(self) -> None:
"""Enable prompt pruning by wrapping the pipe."""
self.wrap_pipe(self.pipe)

def disable(self) -> None:
"""Disable prompt pruning by unwrapping the pipe."""
if self.pipe_call:
self.pipe.__call__ = self.pipe_call

Bug: Helper disable may crash before enable

disable accesses self.pipe_call without initializing it in __init__. If disable is called before enable/wrap_pipe runs (or after partial initialization failures), this can raise AttributeError and prevent cleanup/saving flows that call .disable() on helpers.

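A minimal fix sketch (one option, not necessarily the patch the authors intend): initialize the attribute in __init__ and make disable a no-op until enable has run.

    def __init__(self, pipe: Any, min_tokens: int, tokenizer: Any) -> None:
        self.pipe = pipe
        self.min_tokens = min_tokens
        self.tokenizer = tokenizer
        # not wrapped yet; disable() restores the original call only when this is set
        self.pipe_call = None

    def disable(self) -> None:
        """Disable padding pruning by unwrapping the pipe."""
        if self.pipe_call is not None:
            self.pipe.__call__ = self.pipe_call
            self.pipe_call = None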

def wrap_pipe(self, pipe: Any) -> None:
"""
Wrap the call method of the pipe to adjust the max sequence length.

Parameters
----------
pipe : Any
The diffusers pipeline to wrap.
"""
pipe_call = pipe.__call__
self.pipe_call = pipe_call
signature = inspect.signature(pipe_call)
default_max_sequence_length = signature.parameters["max_sequence_length"].default

@functools.wraps(pipe_call)
def wrapped_call(*args, **kwargs): # noqa: ANN201
# while a natural approach would be to remove all padding tokens,
# we found this to degrade the quality of the generated images
# for this reason, we usually round to the nearest order of two
# and use this as the max sequence length

# the min_tokens parameter controls the minimum for the max sequence length
min_sequence_length = self.min_tokens
# we use the default value as the maximum value for the max sequence length
max_sequence_length = kwargs.get("max_sequence_length", default_max_sequence_length)

prompts = self._extract_prompts(args, kwargs)
max_num_tokens = max(len(self.tokenizer.encode(p)) for p in prompts)

Bug: Fails when the call provides no prompt strings

wrapped_call computes max_num_tokens via max(... for p in prompts), but _extract_prompts can return an empty list (e.g., callers pass prompt_embeds/no prompt strings). This triggers a runtime ValueError and breaks inference for valid diffusers call patterns.

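A hedged sketch of one possible guard: fall back to the unmodified call whenever no prompt strings can be recovered, e.g. a prompt_embeds-only invocation.

    prompts = self._extract_prompts(args, kwargs)
    if not prompts:
        # no prompt text to measure: leave max_sequence_length
        # handling entirely to the pipeline
        return pipe_call(*args, **kwargs)
    max_num_tokens = max(len(self.tokenizer.encode(p)) for p in prompts)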

sequence_length = min_sequence_length
while max_num_tokens > sequence_length:
sequence_length *= 2
if sequence_length >= max_sequence_length:
sequence_length = max_sequence_length
kwargs["max_sequence_length"] = sequence_length
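# worked example: with min_tokens=64, a 100-token prompt rounds up to 128,
# a 20-token prompt keeps the 64-token floor, and anything above the
# pipeline default is clamped back to that default
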
Bug: Assumes max_sequence_length has a default

default_max_sequence_length is read from inspect.signature(...).parameters["max_sequence_length"].default. If that default is inspect._empty (or the user passes max_sequence_length=None), comparisons like sequence_length >= max_sequence_length can raise TypeError, causing runtime failures on pipelines with a required max_sequence_length.

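One possible guard, sketched under the assumption that skipping the optimization is acceptable when no upper bound can be determined:

    # in wrap_pipe: treat a missing default as "no upper bound"
    default = signature.parameters["max_sequence_length"].default
    default_max_sequence_length = None if default is inspect.Parameter.empty else default

    # in wrapped_call: also catch an explicit max_sequence_length=None
    max_sequence_length = kwargs.get("max_sequence_length") or default_max_sequence_length
    if max_sequence_length is None:
        return pipe_call(*args, **kwargs)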

return pipe_call(*args, **kwargs)

pipe.__call__ = wrapped_call

def _extract_prompts(self, args: Any, kwargs: Any) -> list[str]:
"""Extract the prompts from the args and kwargs of the pipe call."""
prompts: list[str] = []

# the first arguments of diffusers pipelines are usually the prompts
for arg in args:
if isinstance(arg, str):
prompts.append(arg)
elif isinstance(arg, list):
if len(arg) > 0 and isinstance(arg[0], str):
prompts.extend(arg)
else:
break

for kwarg in kwargs:
if kwarg.startswith("prompt"):
prompt = kwargs[kwarg]
if isinstance(prompt, str):
prompts.append(prompt)
elif isinstance(prompt, list):
prompts.extend(prompt)
if kwarg.startswith("negative_prompt"):
negative_prompt = kwargs[kwarg]
if isinstance(negative_prompt, str):
prompts.append(negative_prompt)
elif isinstance(negative_prompt, list):
prompts.extend(negative_prompt)
return prompts
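Illustrative behaviour of _extract_prompts, with a hypothetical helper instance and values:

    # helper._extract_prompts(("a cat",), {"negative_prompt": "blurry"})
    # -> ["a cat", "blurry"]
    # helper._extract_prompts((), {"prompt_embeds": embeds})
    # -> []   (the empty case flagged in the review comment above)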
22 changes: 22 additions & 0 deletions tests/algorithms/testers/padding_pruning.py
@@ -0,0 +1,22 @@
from pruna.algorithms.padding_pruning import PaddingPruner
from pruna.engine.pruna_model import PrunaModel

from .base_tester import AlgorithmTesterBase


class TestPaddingPruning(AlgorithmTesterBase):
"""Test the padding pruning algorithm."""

models = ["flux_tiny_random_with_tokenizer"]
reject_models = ["opt_tiny_random"]
allow_pickle_files = False
algorithm_class = PaddingPruner
metrics = ["cmmd"]

def post_smash_hook(self, model: PrunaModel) -> None:
"""Hook to modify the model after smashing."""
assert hasattr(model, "padding_pruning_helper")
model.text_encoder.resize_token_embeddings(model.smash_config.tokenizer.vocab_size)

if hasattr(model, "text_encoder_2"):
model.text_encoder_2.resize_token_embeddings(model.smash_config.tokenizer.vocab_size)
12 changes: 12 additions & 0 deletions tests/fixtures.py
@@ -91,6 +91,15 @@ def get_diffusers_model(model_id: str, **kwargs: dict[str, Any]) -> tuple[Any, SmashConfig]:
return model, smash_config


def get_diffusers_model_with_tokenizer(model_id: str, **kwargs: dict[str, Any]) -> tuple[Any, SmashConfig]:
"""Get a diffusers model for image generation."""
model, _ = get_diffusers_model(model_id, **kwargs)
smash_config = SmashConfig()
smash_config.add_data("LAION256")
smash_config.add_tokenizer("openai/clip-vit-base-patch32")
return model, smash_config


def get_automodel_transformers(model_id: str, **kwargs: dict[str, Any]) -> tuple[Any, SmashConfig]:
"""Get an AutoModelForCausalLM model for text generation."""
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
@@ -184,6 +193,9 @@ def get_autoregressive_text_to_image_model(model_id: str) -> tuple[Any, SmashConfig]:
"sd_tiny_random": partial(get_diffusers_model, "dg845/tiny-random-stable-diffusion"),
"sana_tiny_random": partial(get_diffusers_model, "katuni4ka/tiny-random-sana"),
"flux_tiny_random": partial(get_diffusers_model, "katuni4ka/tiny-random-flux", torch_dtype=torch.bfloat16),
"flux_tiny_random_with_tokenizer": partial(
get_diffusers_model_with_tokenizer, "katuni4ka/tiny-random-flux", torch_dtype=torch.float16
),
# text generation models
"opt_tiny_random": partial(get_automodel_transformers, "yujiepan/opt-tiny-random"),
"smollm_135m": partial(get_automodel_transformers, "HuggingFaceTB/SmolLM2-135M"),