5 changes: 4 additions & 1 deletion pyproject.toml
@@ -136,7 +136,10 @@ dependencies = [
"aenum",
"vbench-pruna; sys_platform != 'darwin'",
"imageio-ffmpeg",
"jaxtyping"
"jaxtyping",
"peft>=0.18.0",
"trl<=0.21.0",
"termcolor==2.3.0",
]

[project.optional-dependencies]
8 changes: 8 additions & 0 deletions src/pruna/algorithms/base/tags.py
@@ -64,6 +64,14 @@ class AlgorithmTag(Enum):
"batcher",
"Batching groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing overall processing time.",
)
RECOVERER = (
"recoverer",
"Recovery restores the performance of a model after compression.",
)
DISTILLER = (
"distiller",
"Distillation trains a smaller, simpler model to mimic a larger, more complex model.",
)

def __init__(self, name: str, description: str):
"""
73 changes: 73 additions & 0 deletions src/pruna/algorithms/distillation_perp.py
@@ -0,0 +1,73 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Iterable

from pruna.algorithms.base.tags import AlgorithmTag
from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer


class TextToImagePERPDistillation(PERPRecoverer):
"""
PERP distillation recoverer for text-to-image models.

    This recoverer is a general-purpose PERP recoverer for text-to-image models that uses norm and bias
    finetuning as well as LoRA layers.

Parameters
----------
use_lora : bool
Whether to use LoRA adapters.
use_in_place : bool
        Whether to use norm and bias finetuning, which will modify the model in place.
"""

group_tags: list[AlgorithmTag] = [AlgorithmTag.DISTILLER, AlgorithmTag.RECOVERER] # type: ignore[attr-defined]
algorithm_name = "text_to_image_distillation_perp"
tokenizer_required = False
compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache"]
compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"]
runs_on: list[str] = ["cuda"]

def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None:
super().__init__(task_name="text_to_image", use_lora=use_lora, use_in_place=use_in_place, is_distillation=True)


class TextToImageInPlacePERPDistillation(TextToImagePERPDistillation):
"""
PERP distillation recoverer for text-to-image models without LoRA adapters.

    This is the same as ``text_to_image_distillation_perp``, but without the LoRA layers, which add extra
    computation and thus slow down inference of the final model.
"""

algorithm_name = "text_to_image_distillation_inplace_perp"

def __init__(self) -> None:
super().__init__(use_lora=False, use_in_place=True)


class TextToImageLoraDistillation(TextToImagePERPDistillation):
"""
LoRA distillation recoverer for text-to-image models.

This recoverer attaches LoRA adapters to the model and uses them for distillation.
"""

algorithm_name = "text_to_image_distillation_lora"

def __init__(self) -> None:
super().__init__(use_lora=True, use_in_place=False)
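
For orientation, the sketch below shows how a further variant could be declared by reusing the same PERPRecoverer subclass pattern as above; the class name, the algorithm name, the "text_to_text" task name, and the tokenizer_required/compatibility choices are all hypothetical illustrations, not part of this PR.

from __future__ import annotations

from typing import Iterable

from pruna.algorithms.base.tags import AlgorithmTag
from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer


class TextToTextPERPDistillation(PERPRecoverer):
    """Hypothetical PERP distillation recoverer for text-to-text models (illustrative sketch only)."""

    group_tags: list[AlgorithmTag] = [AlgorithmTag.DISTILLER, AlgorithmTag.RECOVERER]  # type: ignore[attr-defined]
    algorithm_name = "text_to_text_distillation_perp"  # hypothetical name, not registered by this PR
    tokenizer_required = True  # assumption: a text model would likely need its tokenizer for distillation data
    compatible_before: Iterable[str | AlgorithmTag] = []
    compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"]
    runs_on: list[str] = ["cuda"]

    def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None:
        # "text_to_text" mirrors the "text_to_image" task name used above; whether PERPRecoverer
        # accepts this task is an assumption made purely for illustration.
        super().__init__(task_name="text_to_text", use_lora=use_lora, use_in_place=use_in_place, is_distillation=True)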
13 changes: 13 additions & 0 deletions src/pruna/algorithms/global_utils/recovery/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
101 changes: 101 additions & 0 deletions src/pruna/algorithms/global_utils/recovery/adapters/__init__.py
@@ -0,0 +1,101 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

import torch

from pruna.config.smash_config import SmashConfigPrefixWrapper


class PrunaAdapter(ABC):
"""Base class for adapters, defining which parameters to finetune for recovery."""

@property
@abstractmethod
def adapter_prefix(self) -> str:
"""The prefix of the adapter to use in the config."""
pass

@classmethod
@abstractmethod
def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list:
"""
Configure all algorithm-specific hyperparameters with ConfigSpace.

Parameters
----------
task_name : str
            The name of the task, e.g. "text_to_image" or "text_to_text".
**override_defaults : Any
Values used to override the default hyperparameters when using multiple finetuners together.

Returns
-------
list
The hyperparameters.
"""
pass

@classmethod
@abstractmethod
def activate(
cls,
model: torch.nn.Module,
smash_config: SmashConfigPrefixWrapper,
seed: int | None = None,
) -> tuple[torch.nn.Module, int, int]:
"""
Activate or create the parameters in the model corresponding to the adapter.

Parameters
----------
model : torch.nn.Module
The model to apply the component to.
smash_config : SmashConfigPrefixWrapper
The configuration for the component.
        seed : int | None
The seed to use for the adapter if it requires initialization.

Returns
-------
torch.nn.Module
The model with the adapter activated.
int
The number of trainable parameters.
int
The number of skipped parameters.
"""
pass

@classmethod
def pre_smash_hook(
cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int | None = None
) -> None:
"""
Optional hook to prepare the model/config before smashing.

Parameters
----------
model : torch.nn.Module
The model to prepare.
smash_config : SmashConfigPrefixWrapper
Configuration scoped to this adapter.
seed : int | None
Optional seed for deterministic initialization.
"""
pass
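
In addition to the bias and head adapters in this PR, a minimal concrete implementation of this interface might look like the sketch below; the NormAdapter class is hypothetical, and selecting LayerNorm/GroupNorm modules by type stands in for whatever selection logic a real adapter would use.

import torch

from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter


class NormAdapter(PrunaAdapter):
    """Hypothetical adapter that unfreezes normalization-layer parameters (illustrative sketch only)."""

    adapter_prefix = "norm"

    @classmethod
    def get_hyperparameters(cls, *args, **kwargs) -> list:
        # No method-specific hyperparameters in this sketch.
        return []

    @classmethod
    def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]:
        # Unfreeze the parameters of normalization layers, counting trainable and skipped parameters
        # to satisfy the (model, num_trainable, num_skipped) contract of PrunaAdapter.activate.
        num_activ_param, num_skip_param = 0, 0
        for module in model.modules():
            if isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
                for param in module.parameters(recurse=False):
                    if param.is_floating_point():
                        param.requires_grad = True
                        num_activ_param += int(param.numel())
                    else:
                        # Non-floating-point parameters cannot require gradients; count them as skipped.
                        num_skip_param += int(param.numel())
        return model, num_activ_param, num_skip_param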
68 changes: 68 additions & 0 deletions src/pruna/algorithms/global_utils/recovery/adapters/bias.py
@@ -0,0 +1,68 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils


class BiasAdapter(PrunaAdapter):
"""Adapter for bias finetuning."""

adapter_prefix = "bias"

@classmethod
def get_hyperparameters(cls, *args, **kwargs) -> list:
"""
Configure all method-specific hyperparameters with ConfigSpace.

Parameters
----------
*args : Any
Unused arguments.
**kwargs : Any
Unused keyword arguments.

Returns
-------
list
The hyperparameters.
"""
return []

@classmethod
def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]:
"""
Activate all biases for training.

Parameters
----------
model : torch.nn.Module
The model containing the biases.
*args : Any
Unused additional arguments.
**kwargs : Any
Unused additional keyword arguments.

Returns
-------
torch.nn.Module
The model with the biases activated.
int
The number of trainable bias parameters.
int
The number of skipped bias parameters.
"""
num_activ_param, num_skip_param = utils.unfreeze_parameters_by_name(model, target_modules=("bias",))
return model, num_activ_param, num_skip_param
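
The helper utils.unfreeze_parameters_by_name is not part of this diff; as a rough idea of what the call above relies on, a minimal stand-in could look like the sketch below (the matching rule and return convention are assumptions, and the actual Pruna helper may differ).

import torch


def unfreeze_parameters_by_name(model: torch.nn.Module, target_modules: tuple[str, ...]) -> tuple[int, int]:
    """Illustrative stand-in for utils.unfreeze_parameters_by_name, not the actual Pruna implementation."""
    num_activ_param, num_skip_param = 0, 0
    for name, param in model.named_parameters():
        # Match parameters whose dotted name contains one of the target components, e.g. "bias".
        if not any(target in name.split(".") for target in target_modules):
            continue
        if param.is_floating_point():
            param.requires_grad = True
            num_activ_param += int(param.numel())
        else:
            # Integer/bool parameters cannot require gradients; count them as skipped.
            num_skip_param += int(param.numel())
    return num_activ_param, num_skip_param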
95 changes: 95 additions & 0 deletions src/pruna/algorithms/global_utils/recovery/adapters/head.py
@@ -0,0 +1,95 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import torch

from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils
from pruna.logging.logger import pruna_logger


class HeadAdapter(PrunaAdapter):
"""Adapter for finetuning the model's head while keeping the backbone as is."""

adapter_prefix = "head"

@classmethod
def get_hyperparameters(cls, *args, **kwargs) -> list:
"""
Configure all method-specific hyperparameters with ConfigSpace.

Parameters
----------
        *args : tuple
            Unused arguments.
        **kwargs : dict
            Unused keyword arguments.

Returns
-------
list
The hyperparameters.
"""
return []

@classmethod
def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]:
"""
Activate the model's head for training.

Parameters
----------
model : torch.nn.Module
The model containing the head.
        *args : tuple
            Unused additional arguments.
        **kwargs : dict
            Unused additional keyword arguments.

Returns
-------
torch.nn.Module
The model with the head activated.
int
The number of trainable head parameters.
int
The number of skipped head parameters.
"""
# find head from type and name
model_heads = [
component
for comp_name, component in inspect.getmembers(model)
if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
]
if len(model_heads) != 1:
# = 0: model with no head, e.g. diffusers
# > 1: model with multiple heads, e.g. for localization, not currently supported
model_head_names = [h[0] for h in model_heads] # type: ignore[index]
Review comment: Incorrect indexing of Linear modules instead of names (Medium Severity)

The code attempts [h[0] for h in model_heads] to extract head names, but model_heads is a list of torch.nn.Linear modules (not tuples). The list comprehension at lines 71-75 stores only component, discarding comp_name. When there are multiple heads (>1), indexing a Linear module with [0] will raise a TypeError. The fix requires storing (comp_name, component) tuples in model_heads or collecting names separately.

Suggested change (Collaborator), replacing

    model_head_names = [h[0] for h in model_heads]  # type: ignore[index]

with

    model_head_names = [
        comp_name
        for comp_name, component in inspect.getmembers(model)
        if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
    ]

Member: a bit hacky, but okay for me, or we also collect the name in model_heads in line 71, then we don't have to go through the model twice

pruna_logger.warning(
f"Found multiple heads but expected only one: {model_head_names}. Skipping head finetuning."
)
return model, 0, 0
model_head = model_heads[0]

# unfreeze head parameters, recording the number of trainable and skipped parameters
num_activ_param, num_skip_param = 0, 0
for param in model_head.parameters():
if utils.is_trainable(param):
param.requires_grad = True
num_activ_param += int(param.numel())
else:
num_skip_param += int(param.numel())

return model, num_activ_param, num_skip_param
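
Following the reviewer discussion above, one way to collect the head names without traversing the model a second time is to gather (name, module) pairs in the first pass; the fragment below is an illustrative sketch of that variant for the head-discovery block inside HeadAdapter.activate, not what the PR currently contains.

        # Sketch: collect (name, module) pairs once so head names are available without a second pass.
        model_heads = [
            (comp_name, component)
            for comp_name, component in inspect.getmembers(model)
            if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
        ]
        if len(model_heads) != 1:
            model_head_names = [comp_name for comp_name, _ in model_heads]
            pruna_logger.warning(
                f"Found multiple heads but expected only one: {model_head_names}. Skipping head finetuning."
            )
            return model, 0, 0
        _, model_head = model_heads[0]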