From 8215794f0f3abdaf4890de6b1f13086fbe85e63f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 5 Jan 2026 09:08:32 +0000 Subject: [PATCH 1/4] Added PERP recovery algorithms without distillation --- pyproject.toml | 5 +- src/pruna/algorithms/base/tags.py | 4 + .../global_utils/recovery/__init__.py | 13 + .../recovery/adapters/__init__.py | 101 ++++ .../global_utils/recovery/adapters/bias.py | 68 +++ .../global_utils/recovery/adapters/head.py | 95 +++ .../global_utils/recovery/adapters/lora.py | 313 ++++++++++ .../global_utils/recovery/adapters/norm.py | 85 +++ .../global_utils/recovery/adapters/utils.py | 178 ++++++ .../recovery/finetuners/__init__.py | 70 +++ .../diffusers/distillation_arg_utils.py | 180 ++++++ .../finetuners/diffusers/pack_and_predict.py | 164 ++++++ .../diffusers/scheduler_interface.py | 165 ++++++ .../recovery/finetuners/diffusers/utils.py | 192 ++++++ .../finetuners/text_to_image_finetuner.py | 557 ++++++++++++++++++ .../finetuners/text_to_text_finetuner.py | 256 ++++++++ .../global_utils/recovery/perp_recoverer.py | 321 ++++++++++ .../algorithms/global_utils/recovery/utils.py | 147 +++++ src/pruna/algorithms/perp.py | 110 ++++ tests/algorithms/testers/tti_inplace_perp.py | 48 ++ tests/algorithms/testers/tti_lora.py | 48 ++ tests/algorithms/testers/tti_perp.py | 46 ++ tests/algorithms/testers/ttt_inplace_perp.py | 45 ++ tests/algorithms/testers/ttt_lora.py | 45 ++ tests/algorithms/testers/ttt_perp.py | 45 ++ tests/algorithms/testers/utils.py | 10 + 26 files changed, 3310 insertions(+), 1 deletion(-) create mode 100644 src/pruna/algorithms/global_utils/recovery/__init__.py create mode 100644 src/pruna/algorithms/global_utils/recovery/adapters/__init__.py create mode 100644 src/pruna/algorithms/global_utils/recovery/adapters/bias.py create mode 100644 src/pruna/algorithms/global_utils/recovery/adapters/head.py create mode 100644 src/pruna/algorithms/global_utils/recovery/adapters/lora.py create mode 100644 src/pruna/algorithms/global_utils/recovery/adapters/norm.py create mode 100644 src/pruna/algorithms/global_utils/recovery/adapters/utils.py create mode 100644 src/pruna/algorithms/global_utils/recovery/finetuners/__init__.py create mode 100644 src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/distillation_arg_utils.py create mode 100644 src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/pack_and_predict.py create mode 100644 src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/scheduler_interface.py create mode 100644 src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/utils.py create mode 100644 src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py create mode 100644 src/pruna/algorithms/global_utils/recovery/finetuners/text_to_text_finetuner.py create mode 100644 src/pruna/algorithms/global_utils/recovery/perp_recoverer.py create mode 100644 src/pruna/algorithms/global_utils/recovery/utils.py create mode 100644 src/pruna/algorithms/perp.py create mode 100644 tests/algorithms/testers/tti_inplace_perp.py create mode 100644 tests/algorithms/testers/tti_lora.py create mode 100644 tests/algorithms/testers/tti_perp.py create mode 100644 tests/algorithms/testers/ttt_inplace_perp.py create mode 100644 tests/algorithms/testers/ttt_lora.py create mode 100644 tests/algorithms/testers/ttt_perp.py diff --git a/pyproject.toml b/pyproject.toml index 44df84bd..a919085f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,7 +136,10 @@ dependencies = [ "aenum", "vbench-pruna; sys_platform != 'darwin'",
"imageio-ffmpeg", - "jaxtyping" + "jaxtyping", + "peft>=0.18.0", + "trl<=0.21.0", + "termcolor==2.3.0", ] [project.optional-dependencies] diff --git a/src/pruna/algorithms/base/tags.py b/src/pruna/algorithms/base/tags.py index f2e37e6f..2995effb 100644 --- a/src/pruna/algorithms/base/tags.py +++ b/src/pruna/algorithms/base/tags.py @@ -64,6 +64,10 @@ class AlgorithmTag(Enum): "batcher", "Batching groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing overall processing time.", ) + RECOVERER = ( + "recoverer", + "Recovery restores the performance of a model after compression.", + ) def __init__(self, name: str, description: str): """ diff --git a/src/pruna/algorithms/global_utils/recovery/__init__.py b/src/pruna/algorithms/global_utils/recovery/__init__.py new file mode 100644 index 00000000..38e0d7e5 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/pruna/algorithms/global_utils/recovery/adapters/__init__.py b/src/pruna/algorithms/global_utils/recovery/adapters/__init__.py new file mode 100644 index 00000000..235e2079 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/adapters/__init__.py @@ -0,0 +1,101 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import torch + +from pruna.config.smash_config import SmashConfigPrefixWrapper + + +class PrunaAdapter(ABC): + """Base class for adapters, defining which parameters to finetune for recovery.""" + + @property + @abstractmethod + def adapter_prefix(self) -> str: + """The prefix of the adapter to use in the config.""" + pass + + @classmethod + @abstractmethod + def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list: + """ + Configure all algorithm-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + task_name : str + The name of the task, e.g. "text-to-image" or "text-to-text". + **override_defaults : Any + Values used to override the default hyperparameters when using multiple finetuners together. + + Returns + ------- + list + The hyperparameters. 
+ """ + pass + + @classmethod + @abstractmethod + def activate( + cls, + model: torch.nn.Module, + smash_config: SmashConfigPrefixWrapper, + seed: int | None = None, + ) -> tuple[torch.nn.Module, int, int]: + """ + Activate or create the parameters in the model corresponding to the adapter. + + Parameters + ---------- + model : torch.nn.Module + The model to apply the component to. + smash_config : SmashConfigPrefixWrapper + The configuration for the component. + seed : int + The seed to use for the adapter if it requires initialization. + + Returns + ------- + torch.nn.Module + The model with the adapter activated. + int + The number of trainable parameters. + int + The number of skipped parameters. + """ + pass + + @classmethod + def pre_smash_hook( + cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int | None = None + ) -> None: + """ + Optional hook to prepare the model/config before smashing. + + Parameters + ---------- + model : torch.nn.Module + The model to prepare. + smash_config : SmashConfigPrefixWrapper + Configuration scoped to this adapter. + seed : int | None + Optional seed for deterministic initialization. + """ + pass diff --git a/src/pruna/algorithms/global_utils/recovery/adapters/bias.py b/src/pruna/algorithms/global_utils/recovery/adapters/bias.py new file mode 100644 index 00000000..d2c525db --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/adapters/bias.py @@ -0,0 +1,68 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils + + +class BiasAdapter(PrunaAdapter): + """Adapter for bias finetuning.""" + + adapter_prefix = "bias" + + @classmethod + def get_hyperparameters(cls, *args, **kwargs) -> list: + """ + Configure all method-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + *args : Any + Unused arguments. + **kwargs : Any + Unused keyword arguments. + + Returns + ------- + list + The hyperparameters. + """ + return [] + + @classmethod + def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]: + """ + Activate all biases for training. + + Parameters + ---------- + model : torch.nn.Module + The model containing the biases. + *args : Any + Unused additional arguments. + **kwargs : Any + Unused additional keyword arguments. + + Returns + ------- + torch.nn.Module + The model with the biases activated. + int + The number of trainable bias parameters. + int + The number of skipped bias parameters. 
+ """ + num_activ_param, num_skip_param = utils.unfreeze_parameters_by_name(model, target_modules=("bias",)) + return model, num_activ_param, num_skip_param diff --git a/src/pruna/algorithms/global_utils/recovery/adapters/head.py b/src/pruna/algorithms/global_utils/recovery/adapters/head.py new file mode 100644 index 00000000..6bb38c31 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/adapters/head.py @@ -0,0 +1,95 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils +from pruna.logging.logger import pruna_logger + + +class HeadAdapter(PrunaAdapter): + """Adapter for finetuning the model's head while keeping the backbone as is.""" + + adapter_prefix = "head" + + @classmethod + def get_hyperparameters(cls, *args, **kwargs) -> list: + """ + Configure all method-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + *args : tuple + The arguments for the adapter. + **kwargs : dict + The hyperparameters for the adapter. + + Returns + ------- + list + The hyperparameters. + """ + return [] + + @classmethod + def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]: + """ + Activate the model's head for training. + + Parameters + ---------- + model : torch.nn.Module + The model containing the head. + *args : tuple + The arguments for the adapter. + **kwargs : dict + The hyperparameters for the adapter. + + Returns + ------- + torch.nn.Module + The model with the head activated. + int + The number of trainable head parameters. + int + The number of skipped head parameters. + """ + # find head from type and name + model_heads = [ + component + for comp_name, component in inspect.getmembers(model) + if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower() + ] + if len(model_heads) != 1: + # = 0: model with no head, e.g. diffusers + # > 1: model with multiple heads, e.g. for localization, not currently supported + model_head_names = [h[0] for h in model_heads] # type: ignore[index] + pruna_logger.warning( + f"Found multiple heads but expected only one: {model_head_names}. Skipping head finetuning." + ) + return model, 0, 0 + model_head = model_heads[0] + + # unfreeze head parameters, recording the number of trainable and skipped parameters + num_activ_param, num_skip_param = 0, 0 + for param in model_head.parameters(): + if utils.is_trainable(param): + param.requires_grad = True + num_activ_param += int(param.numel()) + else: + num_skip_param += int(param.numel()) + + return model, num_activ_param, num_skip_param diff --git a/src/pruna/algorithms/global_utils/recovery/adapters/lora.py b/src/pruna/algorithms/global_utils/recovery/adapters/lora.py new file mode 100644 index 00000000..b9e8b80b --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/adapters/lora.py @@ -0,0 +1,313 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch +from ConfigSpace import CategoricalHyperparameter, Constant, OrdinalHyperparameter +from diffusers.models import ( + FluxTransformer2DModel, + SanaTransformer2DModel, + UNet2DConditionModel, +) +from peft import LoraConfig, PeftMixedModel, PeftModel, get_peft_model +from pytorch_lightning import seed_everything +from pytorch_lightning.utilities.seed import isolate_rng + +from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import is_causal_lm +from pruna.engine.utils import get_device +from pruna.logging.logger import pruna_logger + + +class LoraAdapter(PrunaAdapter): + """Adapter for LoRA finetuning.""" + + adapter_prefix = "lora" + + @classmethod + def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list: + """ + Configure all algorithm-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + task_name : str + The name of the task, e.g. "text-to-image" or "text-to-text". + **override_defaults : Any + Values used to override the default hyperparameters when using multiple finetuners together. + + Returns + ------- + list + The hyperparameters. 
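+ + Examples + -------- + Illustrative sketch of the search space exposed for causal-LM recovery: + + >>> hps = LoraAdapter.get_hyperparameters("text_to_text") + >>> [hp.name for hp in hps] + ['r', 'alpha_r_ratio', 'target_modules', 'dropout', 'variant']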
+ """ + if task_name == "text_to_text": + # default values are based on + # https://github.com/huggingface/smollm/blob/6f2fbbb76f77c2f0db355a9d3cd2167ae2a11854/finetuning/train.py, + default_hyperparameters = { + "r": 8, + "alpha_r_ratio": 2.0, + "target_modules": None, # None is handled by the peft package in peft.utils.constants + # see TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING + "dropout": 0.05, + "variant": "lora", + } + default_hyperparameters.update(override_defaults) + return [ + OrdinalHyperparameter( + "r", + sequence=[4, 8, 16, 32, 64, 128], + default_value=default_hyperparameters["r"], + meta=dict(desc="Rank of the LoRA layers."), + ), + OrdinalHyperparameter( + "alpha_r_ratio", + sequence=[0.5, 1.0, 2.0], + default_value=default_hyperparameters["alpha_r_ratio"], + meta=dict(desc="Alpha/Rank ratio of the LoRA layers."), + ), + CategoricalHyperparameter( + "target_modules", + choices=[None, "all-linear"], + default_value=default_hyperparameters["target_modules"], + meta=dict(desc="Target modules for the LoRA layers."), + ), + Constant( + "dropout", + default_hyperparameters["dropout"], + meta=dict(desc="Dropout rate of the LoRA layers during training."), + ), + CategoricalHyperparameter( + "variant", + choices=["lora", "pissa"], + default_value=default_hyperparameters["variant"], + meta=dict(desc="Variant of the LoRA adapter."), + ), + ] + + elif task_name == "text_to_image": + # default values are based on + # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py, + default_hyperparameters = { + "r": 4, + "alpha_r_ratio": 1.0, + "target_modules": None, + "dropout": 0.0, + "variant": "lora", + } + default_hyperparameters.update(override_defaults) + return [ + OrdinalHyperparameter( + "r", + sequence=[4, 8, 16, 32, 64, 128], + default_value=default_hyperparameters["r"], + meta=dict(desc="Rank of the LoRA layers."), + ), + OrdinalHyperparameter( + "alpha_r_ratio", + sequence=[0.5, 1.0, 2.0], + default_value=default_hyperparameters["alpha_r_ratio"], + meta=dict(desc="Alpha/Rank ratio of the LoRA layers."), + ), + Constant( + "target_modules", default_hyperparameters["target_modules"] + ), # default choice depends on the model, allow user to choose in future + Constant("dropout", default_hyperparameters["dropout"]), + CategoricalHyperparameter( + "variant", + choices=["lora", "pissa"], + default_value=default_hyperparameters["variant"], + meta=dict(desc="Variant of the LoRA adapter."), + ), + ] + else: + raise ValueError(f"Task '{task_name}' is not yet supported for LoRA recovery.") + + @classmethod + def activate( + cls, + model: torch.nn.Module, + smash_config: SmashConfigPrefixWrapper, + seed: int | None = None, + ) -> tuple[torch.nn.Module, int, int]: + """ + Create LoRA layers. + + Parameters + ---------- + model : torch.nn.Module + The model to attach LoRA layers to. + smash_config : SmashConfigPrefixWrapper + The configuration to use for defining LoRA layers. + seed : int | None + The seed used to reproducibly initialize the LoRA layers. + + Returns + ------- + torch.nn.Module + The model with the LoRA layers activated. + int + The number of trainable LoRA parameters. + int + The number of skipped LoRA parameters. 
+ """ + # save active parameters, device and dtype to restore after getting peft model + active_parameters = [param for param in model.parameters() if param.requires_grad] + device = get_device(model) + + if smash_config["variant"] == "lora": + # define LoRA layers + target_modules = smash_config["target_modules"] + if target_modules is None: + target_modules = cls.get_default_target_modules(model) + lora_config = LoraConfig( + r=smash_config["r"], + lora_alpha=int(smash_config["alpha_r_ratio"] * smash_config["r"]), + lora_dropout=float(smash_config["dropout"]), + target_modules=target_modules, + bias="none", + ) + model = _get_peft_model_with_seed(model, lora_config, seed=seed) + elif smash_config["variant"] == "pissa": + model = PeftModel.from_pretrained( + model, smash_config.cache_dir / "pissa_weights", is_trainable=True, torch_device=device + ) + else: + raise ValueError(f"Invalid LoRA variant: {smash_config['variant']}") + model.to(device=device) + + # count trainable LoRA parameters + num_lora_params = sum(p.numel() for name, p in model.named_parameters() if "lora" in name and p.requires_grad) + + # restore active parameters + for param in active_parameters: + param.requires_grad = True + return model, num_lora_params, 0 + + @classmethod + def pre_smash_hook( + cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int | None = None + ) -> None: + """ + Compute LoRA weights before smashing in case of a variant that requires the original weights. + + PiSSA initilization involves changing the base weights of the model, + which we want to apply adapters to during smashing. + Therefore, we need to apply the PiSSA adapter before smashing and save the adapter weights to a temporary file. + This file will be loaded during smashing and the adapter weights will be applied to the base weights. + + Parameters + ---------- + model : torch.nn.Module + The model to prepare. + smash_config : SmashConfigPrefixWrapper + The configuration to use for defining LoRA layers. + seed : int | None + The seed used to reproducibly initialize the LoRA layers. + """ + if smash_config["variant"] == "pissa": + pruna_logger.info("Performing pre-smash setup for PiSSA adapter.") + target_modules = smash_config["target_modules"] + if target_modules is None: + target_modules = cls.get_default_target_modules(model) + lora_config = LoraConfig( + r=smash_config["r"], + lora_alpha=int(smash_config["alpha_r_ratio"] * smash_config["r"]), + lora_dropout=float(smash_config["dropout"]), + target_modules=target_modules, + bias="none", + init_lora_weights="pissa_niter_4", # type: ignore[arg-type] + ) + model = _get_peft_model_with_seed(model, lora_config, seed=seed) + # reset LoRA initialization to default to avoid computing PiSSA weights a second time when loading + model.peft_config["default"].init_lora_weights = True # type: ignore[attr-defined] + pruna_logger.info(f"Saving PiSSA weights to {smash_config.cache_dir / 'pissa_weights'}") + model.save_pretrained(smash_config.cache_dir / "pissa_weights") + model.unload() + + @staticmethod + def get_default_target_modules(model: Any) -> list[str] | None: + """ + Return default target modules based on huggingface's finetuning scripts. + + Parameters + ---------- + model : Any + The model to get the default target modules for. + + Returns + ------- + list[str] | None + The default target modules. 
+ """ + if is_causal_lm(model): + return None + elif isinstance(model, UNet2DConditionModel): # SD and SDXL + return ["to_k", "to_q", "to_v", "to_out.0"] + elif isinstance(model, SanaTransformer2DModel): # Sana + return ["to_k", "to_q", "to_v"] + elif isinstance(model, FluxTransformer2DModel): # Flux + return [ + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + ] + else: + pruna_logger.warning( + "Could not infer the target modules in the pipeline, " + "falling back to peft-defined defaults if peft recognizes the model architecture." + ) + # let the peft package handle the default target modules (see is_causal_lm case) + return None + + +def _get_peft_model_with_seed(*args: Any, seed: int | None = None, **kwargs: Any) -> PeftModel | PeftMixedModel: + """ + Call get_peft_model with a seed for reproducible initialization. + + Parameters + ---------- + *args + The arguments to pass to get_peft_model. + seed : int | None + The seed to use for the model. + **kwargs + The keyword arguments to pass to get_peft_model. + + Returns + ------- + PeftModel | PeftMixedModel + The peft model. + """ + if seed is None: + return get_peft_model(*args, **kwargs) + + with isolate_rng(): + seed_everything(seed, verbose=False) + model = get_peft_model(*args, **kwargs) # type: ignore[arg-type] + + return model diff --git a/src/pruna/algorithms/global_utils/recovery/adapters/norm.py b/src/pruna/algorithms/global_utils/recovery/adapters/norm.py new file mode 100644 index 00000000..f5c921c6 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/adapters/norm.py @@ -0,0 +1,85 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils + +# Normalization layers must be scraped by class name to match both torch, diffusers and other implementations. +# Matching by module names also does not work because it ends up matching e.g. AdaLayerNormZero which contain +# high dimensional linear layers. +NORM_CLASS_NAMES: tuple[str, ...] = ( + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "LayerNorm", + "GroupNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "RMSNorm", + "LlamaRMSNorm", +) + + +class NormAdapter(PrunaAdapter): + """Adapter for norm finetuning.""" + + adapter_prefix = "norm" + + @classmethod + def get_hyperparameters(cls, *args, **kwargs) -> list: + """ + Configure all method-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + *args : tuple + The arguments for the adapter. + **kwargs : dict + The keyword arguments for the adapter. + + Returns + ------- + list + The hyperparameters. 
+ """ + return [] + + @classmethod + def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]: + """ + Activate all normalization layers for training. + + Parameters + ---------- + model : torch.nn.Module + The model containing the normalization layers. + *args : Any + Unused arguments. + **kwargs : Any + Unused keyword arguments. + + Returns + ------- + torch.nn.Module + The model with the normalization layers activated. + int + The number of trainable normalization parameters. + int + The number of skipped normalization parameters. + """ + num_activ_param, num_skip_param = utils.unfreeze_submodules_by_class_name(model, target_classes=NORM_CLASS_NAMES) + + return model, num_activ_param, num_skip_param diff --git a/src/pruna/algorithms/global_utils/recovery/adapters/utils.py b/src/pruna/algorithms/global_utils/recovery/adapters/utils.py new file mode 100644 index 00000000..37552c98 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/adapters/utils.py @@ -0,0 +1,178 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch +from torch import nn + +from pruna.engine.utils import get_nn_modules + + +def freeze_parameters(module: Any) -> None: + """ + Disable training for all parameters in the given module. + + Parameters + ---------- + module : torch.nn.Module + The module to freeze. + """ + for nn_module in get_nn_modules(module).values(): + for param in nn_module.parameters(): + param.requires_grad = False + + +def unfreeze_module(module: torch.nn.Module) -> None: + """ + Unfreeze all parameters of the given module. + + Parameters + ---------- + module : torch.nn.Module + The module to unfreeze. + """ + for param in module.parameters(): + param.requires_grad = True + + +def is_trainable(param: torch.nn.Parameter) -> bool: + """ + Check whether the parameter has been quantized, making it untrainable. + + Parameters + ---------- + param : torch.nn.Parameter + The parameter to check for quantization. + + Returns + ------- + bool + Whether the parameter has been quantized. + """ + # Note that quantized weights can be trained by accumulating gradients in a full-precision copy of the model. + # This function will be updated in the future when adding support for this mechanism. + return all( + [ + "float" in str(param.dtype), + not hasattr(param, "qtype"), # check quantization with quanto + ] + ) + + +def unfreeze_parameters_by_name(module: torch.nn.Module, target_modules: tuple[str]) -> tuple[int, int]: + """ + Unfreeze the parameters of the given module when their name contains any of the target_modules. + + Parameters + ---------- + module : torch.nn.Module + The module containing the parameters to unfreeze. + target_modules : tuple[str] + The names of the parameters, or modules containing the parameters to unfreeze. + + Returns + ------- + int + The number of parameters that were activated. 
+ int + The number of parameters found that match the given name but were not trainable. + """ + if len(target_modules) == 0: + return + + activated_parameters, skipped_parameters = 0, 0 + for name, parameter in module.named_parameters(): + matches_name = any(substr in name for substr in target_modules) + if matches_name and is_trainable(parameter): + parameter.requires_grad = True + activated_parameters += int(parameter.numel()) + elif matches_name: # parameter has been quantized + skipped_parameters += int(parameter.numel()) + + return activated_parameters, skipped_parameters + + +def unfreeze_submodules_by_type( + module: torch.nn.Module, + target_types: tuple[type[nn.Module], ...], +) -> tuple[int, int]: + """ + Unfreeze the submodules of the given module when they are of the target type. + + Parameters + ---------- + module : torch.nn.Module + The module containing the submodules to unfreeze. + target_types : tuple[type] + The types identifying which submodules to unfreeze. + + Returns + ------- + int + The number of parameters that were activated. + int + The number of parameters found that match the given type but were not trainable. + """ + if len(target_types) == 0: + return 0, 0 + + activated_parameters, skipped_parameters = 0, 0 + for submodule in module.modules(): + if isinstance(submodule, target_types): + for param in submodule.parameters(): + if is_trainable(param): + param.requires_grad = True + activated_parameters += int(param.numel()) + else: # parameter has been quantized + skipped_parameters += int(param.numel()) + + return activated_parameters, skipped_parameters + + +def unfreeze_submodules_by_class_name( + module: torch.nn.Module, + target_classes: tuple[str, ...], +) -> tuple[int, int]: + """ + Unfreeze the submodules of the given module when their class name matches one of the target classes. + + Parameters + ---------- + module : torch.nn.Module + The module containing the submodules to unfreeze. + target_classes : tuple[str] + The class names identifying which submodules to unfreeze. + + Returns + ------- + int + The number of parameters that were activated. + int + The number of parameters found that match the given type but were not trainable. + """ + if len(target_classes) == 0: + return 0, 0 + + activated_parameters, skipped_parameters = 0, 0 + for submodule in module.modules(): + if submodule.__class__.__name__ in target_classes: + for param in submodule.parameters(): + if is_trainable(param): + param.requires_grad = True + activated_parameters += int(param.numel()) + else: # parameter has been quantized + skipped_parameters += int(param.numel()) + + return activated_parameters, skipped_parameters diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/__init__.py b/src/pruna/algorithms/global_utils/recovery/finetuners/__init__.py new file mode 100644 index 00000000..76028212 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/__init__.py @@ -0,0 +1,70 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import torch + +from pruna.config.smash_config import SmashConfigPrefixWrapper + + +class PrunaFinetuner(ABC): + """Base class for recovery finetuners.""" + + @classmethod + @abstractmethod + def get_hyperparameters(cls, **override_defaults: Any) -> list: + """ + Configure all algorithm-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + **override_defaults : Any + Values used to override the default hyperparameters when using multiple finetuners together. + + Returns + ------- + list + The hyperparameters. + """ + pass + + @classmethod + @abstractmethod + def finetune( + cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str + ) -> torch.nn.Module: + """ + Apply the component to the model: activate parameters for Adapters, or finetune them for Finetuners. + + Parameters + ---------- + model : torch.nn.Module + The model to apply the component to. + smash_config : SmashConfigPrefixWrapper + The configuration for the component. + seed : int + The seed to use for finetuning. + recoverer : str + The name of the recoverer used, for logging purposes. + + Returns + ------- + torch.nn.Module + The model with the component applied. + """ + pass diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/distillation_arg_utils.py b/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/distillation_arg_utils.py new file mode 100644 index 00000000..dcf84b8a --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/distillation_arg_utils.py @@ -0,0 +1,180 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from functools import partial +from typing import Any, Callable, Dict, Tuple + +import torch +from peft import PeftModel + +from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr +from pruna.logging.logger import pruna_logger + + +def get_latent_replacement_fn(pipeline: Any) -> Callable: + """ + Get a function replacing the denoiser's latent argument with the recorded latent. + + Parameters + ---------- + pipeline : Any + The pipeline calling the denoiser. + + Returns + ------- + Callable + A function replacing the denoiser's latent argument with the recorded latent. + """ + expected_arg_name, expected_arg_idx = _get_expected_arg_name_and_idx(pipeline) + return partial(_replace_latent, expected_arg_idx, arg_name=expected_arg_name) + + +def get_latent_extractor_fn(pipeline: Any) -> Callable: + """ + Get a function extracting the latent from the pipeline's input arguments. + + Parameters + ---------- + pipeline : Any + The pipeline calling the denoiser. 
+ + Returns + ------- + Callable + A function extracting the latent from the pipeline's input arguments. + """ + expected_arg_name, expected_arg_idx = _get_expected_arg_name_and_idx(pipeline) + return partial(_extract_latent, expected_arg_idx, arg_name=expected_arg_name) + + +def _get_expected_arg_name_and_idx(pipeline: Any) -> Tuple[str, int]: + """ + Get the expected argument name and index in the denoiser's forward method. + + This is used to generalize latent manipulation across different pipelines. + Some pipelines call their denoiser with latent as the first argument, others + with a name argument. This function returns both a name derived from the pipeline's + architecture, and the index, so both can be used depending on the situation. + + Parameters + ---------- + pipeline : Any + The pipeline calling the denoiser. + + Returns + ------- + Tuple[str, int] + A tuple of the expected argument name and index. + """ + denoiser, denoiser_attr = get_denoiser_attr(pipeline) + if denoiser is None: + raise ValueError("No denoiser attribute found in pipeline") + + if denoiser_attr == "unet": + expected_arg_name = "sample" + elif denoiser_attr == "transformer": + expected_arg_name = "hidden_states" + else: + raise ValueError(f"Unknown denoiser attribute: {denoiser_attr}") + + if isinstance(denoiser, PeftModel): + # PEFTModel does not wrap the forward method with the base model's signature + sig = inspect.signature(denoiser.model.forward) + else: + sig = inspect.signature(denoiser.forward) + + expected_arg_idx = next((i for i, p in enumerate(sig.parameters.keys()) if p == expected_arg_name), None) + if expected_arg_idx is None: + pruna_logger.error(f"Argument '{expected_arg_name}' not found in the denoiser's signature.") + raise ValueError(f"Argument '{expected_arg_name}' not found in the denoiser's signature") + else: + expected_arg_idx -= int("self" in sig.parameters) + + return expected_arg_name, expected_arg_idx + + +def _replace_latent( + arg_idx: int, latent: torch.Tensor, args: Tuple, kwargs: Dict[str, Any], arg_name: str | None = None +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + """ + Replace the argument at the given index with the given name with the given value. + + Parameters + ---------- + arg_idx : int + The index of the argument to replace. + latent : torch.Tensor + The latent to replace the argument with. + args : Tuple + The arguments to the function. + arg_name : str | None + The name of the argument to replace. + kwargs : Dict[str, Any] + The keyword arguments to the function. + + Returns + ------- + Tuple[Tuple, Dict[str, Any]] + The arguments and keyword arguments to the function with the argument replaced. + """ + if arg_name is not None and arg_name in kwargs: + kwargs[arg_name] = latent + elif len(args) > arg_idx: + args = tuple(latent if i == arg_idx else x for i, x in enumerate(args)) + else: + raise ValueError( + "Argument mismatch when replacing denoiser arguments: " + f"attempted to replace {arg_name} at position {arg_idx}, " + f"but found {len(args)} positional arguments, and keyword arguments {list(kwargs.keys())}" + ) + return args, kwargs + + +def _extract_latent(arg_idx: int, args: Tuple, kwargs: Dict[str, Any], arg_name: str | None = None) -> torch.Tensor: + """ + Extract the latent from the pipeline's input arguments. + + Parameters + ---------- + arg_idx : int + The index of the argument to extract. + args : Tuple + The arguments to the function. + kwargs : Dict[str, Any] + The keyword arguments to the function. 
+ arg_name : str | None + The name of the argument to extract. + + Returns + ------- + torch.Tensor + The latent extracted from the pipeline's input arguments. + """ + if arg_name is not None and arg_name in kwargs: + if not isinstance(kwargs[arg_name], torch.Tensor): + raise ValueError(f"Expected a tensor, got {type(kwargs[arg_name])}") + return kwargs[arg_name] + elif len(args) > arg_idx: + if not isinstance(args[arg_idx], torch.Tensor): + raise ValueError(f"Expected a tensor, got {type(args[arg_idx])}") + return args[arg_idx] + else: + raise ValueError( + f"Argument mismatch when extracting denoiser arguments: " + f"attempted to extract {arg_name} at position {arg_idx}, " + f"but found {len(args)} positional arguments, and keyword arguments {list(kwargs.keys())}" + ) diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/pack_and_predict.py b/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/pack_and_predict.py new file mode 100644 index 00000000..adc2b186 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/pack_and_predict.py @@ -0,0 +1,164 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable + +import torch + +from pruna.engine.model_checks import ( + is_flux_pipeline, + is_sana_pipeline, + is_sd_pipeline, + is_sdxl_pipeline, +) + + +def get_pack_and_predict_fn(pipeline: Any) -> Callable: + """ + Get a function to call the denoiser with consistent arguments. + + Parameters + ---------- + pipeline : Any + The pipeline to get the predict noise function from. + + Returns + ------- + Callable + The predict function, taking as arguments the denoiser, the noisy latents, the encoder hidden states, + and the timesteps. 
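+ + Examples + -------- + Hedged sketch; ``pipe``, ``noisy_latents``, ``encoder_hidden_states`` and ``timesteps`` are assumed to already exist on the correct device: + + >>> predict_fn = get_pack_and_predict_fn(pipe)  # doctest: +SKIP + >>> model_pred = predict_fn(pipe, noisy_latents, encoder_hidden_states, timesteps)  # doctest: +SKIP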
+ """ + if is_sd_pipeline(pipeline): + return _pack_and_predict_stable_diffusion + elif is_sdxl_pipeline(pipeline): + return _pack_and_predict_stable_diffusion_xl + elif is_sana_pipeline(pipeline): + return _pack_and_predict_sana + elif is_flux_pipeline(pipeline): + return _pack_and_predict_flux + else: + raise ValueError(f"Unknown pipeline: {pipeline.__class__.__name__}") + + +def _pack_and_predict_stable_diffusion( + pipeline: Any, + noisy_latents: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timesteps: torch.Tensor, +) -> torch.Tensor: + """Format inputs for Stable Diffusion and apply unet.""" + prompt_embeds, _ = encoder_hidden_states + model_pred = pipeline.unet(noisy_latents, encoder_hidden_states=prompt_embeds, timestep=timesteps) + + return model_pred[0] + + +def _pack_and_predict_stable_diffusion_xl( + pipeline: Any, noisy_latents: torch.Tensor, encoder_hidden_states: torch.Tensor, timesteps: torch.Tensor +) -> torch.Tensor: + """Format inputs for Stable Diffusion XL and apply unet.""" + prompt_embeds, _, pooled_prompt_embeds, _ = encoder_hidden_states + + # Get resolution from the latents: + # Multiply by vae_scale_factor since latents are downsampled + height = noisy_latents.shape[2] * pipeline.vae_scale_factor + width = noisy_latents.shape[3] * pipeline.vae_scale_factor + + # Get text encoder projection dim from the pipeline + text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim + + add_time_ids = pipeline._get_add_time_ids( + original_size=(height, width), + crops_coords_top_left=(0, 0), + target_size=(height, width), + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.to(device=noisy_latents.device) + add_time_ids = add_time_ids.repeat(noisy_latents.shape[0], 1) + + added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids} + + model_pred = pipeline.unet( + noisy_latents, + timesteps, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + + return model_pred[0] + + +def _pack_and_predict_sana( + pipeline: Any, noisy_latents: torch.Tensor, encoder_hidden_states: torch.Tensor, timesteps: torch.Tensor +) -> torch.Tensor: + """Format inputs for Sana and apply transformer.""" + prompt_embeds, prompt_attention_mask, _, _ = encoder_hidden_states + return pipeline.transformer( + noisy_latents, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + )[0] + + +def _pack_and_predict_flux( + pipeline: Any, noisy_latents: torch.Tensor, encoder_hidden_states: torch.Tensor, timesteps: torch.Tensor +) -> torch.Tensor: + """Format inputs for Flux and apply transformer.""" + prompt_embeds, pooled_prompt_embeds, text_ids = encoder_hidden_states + latent_image_ids = pipeline._prepare_latent_image_ids( + noisy_latents.shape[0], + noisy_latents.shape[2] // 2, + noisy_latents.shape[3] // 2, + noisy_latents.device, + noisy_latents.dtype, + ) + packed_noisy_model_input = pipeline._pack_latents( + noisy_latents, + batch_size=noisy_latents.shape[0], + num_channels_latents=noisy_latents.shape[1], + height=noisy_latents.shape[2], + width=noisy_latents.shape[3], + ) + # guidance with 3.5 following default value in + # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py + if pipeline.transformer.config.guidance_embeds: + guidance = torch.tensor([3.5], device=noisy_latents.device) + guidance = 
guidance.expand(noisy_latents.shape[0]) + else: + guidance = None + + # predict + model_pred = pipeline.transformer( + hidden_states=packed_noisy_model_input, + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + # unpack for loss computation + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + model_pred = pipeline._unpack_latents( + model_pred, + height=noisy_latents.shape[2] * vae_scale_factor, + width=noisy_latents.shape[3] * vae_scale_factor, + vae_scale_factor=vae_scale_factor, + ) + + return model_pred diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/scheduler_interface.py b/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/scheduler_interface.py new file mode 100644 index 00000000..e771e004 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/scheduler_interface.py @@ -0,0 +1,165 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch +from diffusers import DDPMScheduler, FlowMatchEulerDiscreteScheduler + +from pruna.logging.logger import pruna_logger + + +def get_training_scheduler(scheduler: Any) -> Any: + """ + Initialize a scheduler specifically for training to isolate finetuning from inference. + + Parameters + ---------- + scheduler : Any + The scheduler native to the pipeline. + + Returns + ------- + Any + A scheduler using during training, initialized from the pipeline's configuration. + If no training scheduler could be inferred, returns the native scheduler. + """ + prediction_type = get_prediction_type(scheduler) + if prediction_type == "flow_prediction": + # Sana and Flux schedulers + return FlowMatchEulerDiscreteScheduler.from_config(scheduler.config) + elif prediction_type in ["epsilon", "v_prediction"]: + # DDPM and DDPMSolverMultistepScheduler + return DDPMScheduler.from_config(scheduler.config) + else: + pruna_logger.warning( + f"Could not infer a scheduler for finetuning {scheduler.__class__.__name__}, " + "defaulting to the native scheduler, which may cause numerical issues." + ) + return scheduler + + +def sample_timesteps(training_scheduler: Any, batch_size: int, device: str | torch.device) -> torch.Tensor: + """ + Sample timesteps for the scheduler. + + Parameters + ---------- + training_scheduler : Any + The scheduler used during training. + batch_size : int + The batch size. + device : str | torch.device + The device to sample the timesteps on. + + Returns + ------- + torch.Tensor + The sampled timesteps, with shape (batch_size,). 
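+ + Examples + -------- + Minimal sketch with a default ``DDPMScheduler`` (1000 training timesteps): + + >>> from diffusers import DDPMScheduler + >>> timesteps = sample_timesteps(DDPMScheduler(), batch_size=4, device="cpu") + >>> timesteps.shape + torch.Size([4])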
+ """ + # uniform timesteps for simplicity, replaced with compute_density_for_timestep_sampling in future + indices = torch.randint(0, training_scheduler.config.num_train_timesteps, (batch_size,)).long() + if hasattr(training_scheduler, "timesteps") and training_scheduler.timesteps is not None: + timesteps = training_scheduler.timesteps[indices] + else: + # same distribution as indices but count backwards for consistency with what simple timesteps[indices] does + timesteps = training_scheduler.config.num_train_timesteps - 1 - indices + return timesteps.to(device=device) + + +def add_noise( + training_scheduler: Any, latents: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor +) -> torch.Tensor: + """ + Add noise to the latents using the training scheduler. + + Parameters + ---------- + training_scheduler : Any + The scheduler used during training. + latents : torch.Tensor + The latents to add noise to. + noise : torch.Tensor + The noise to add, with the same shape as `latents`. + timesteps : torch.Tensor + The timesteps used to compute how much noise to add. + + Returns + ------- + torch.Tensor + The noisy latents. + """ + if hasattr(training_scheduler, "add_noise"): + return training_scheduler.add_noise(latents, noise, timesteps) + elif hasattr(training_scheduler, "scale_noise"): + return training_scheduler.scale_noise(latents, timesteps, noise) + else: + raise ValueError("Unknown method for adding noise to latents") + + +def get_target( + training_scheduler: Any, latents: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor +) -> torch.Tensor: + """ + Define the target used to finetune the denoiser. + + Parameters + ---------- + training_scheduler : Any + The scheduler used during training. + latents : torch.Tensor + The latents to add noise to. + noise : torch.Tensor + The noise to add, with the same shape as `latents`. + timesteps : torch.Tensor + The used for the training step and corresponding to the noise levels. + + Returns + ------- + torch.Tensor + The target used to finetune the denoising process. + """ + prediction_type = get_prediction_type(training_scheduler) + if prediction_type == "epsilon": + return noise + elif prediction_type == "v_prediction": + return training_scheduler.get_velocity(latents, noise, timesteps) + elif prediction_type == "flow_prediction": + # Sana and Flux schedulers + return noise - latents + else: + raise ValueError(f"Unknown prediction type or scheduler {prediction_type}") + + +def get_prediction_type(scheduler: Any) -> str: + """ + Get the prediction type from the scheduler. + + Parameters + ---------- + scheduler : Any + The scheduler to get the prediction type from. + + Returns + ------- + str + The prediction type, in ['epsilon', 'v_prediction', 'flow_prediction']. + """ + prediction_type = getattr(scheduler.config, "prediction_type", scheduler.config._class_name) + if prediction_type == "flow_prediction" or "flowmatch" in prediction_type.lower(): + return "flow_prediction" # e.g. Sana and Flux schedulers + else: + return prediction_type diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/utils.py b/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/utils.py new file mode 100644 index 00000000..a79cc308 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/diffusers/utils.py @@ -0,0 +1,192 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from typing import Any, Dict, Tuple, Union, get_args, get_origin + +import torch +from diffusers.utils import BaseOutput + +from pruna.engine.model_checks import ( + is_flux_pipeline, + is_sana_pipeline, + is_sd_pipeline, + is_sdxl_pipeline, +) +from pruna.logging.logger import pruna_logger + + +def get_denoiser_attr(pipeline: Any) -> Tuple[Any | None, str]: + """ + Get the denoiser attribute in a pipeline and its name. + + Parameters + ---------- + pipeline : Any + The pipeline to get the denoiser attribute from. + + Returns + ------- + Tuple[Any | None, str] + The denoiser attribute and its name. If no attribute is found, return None and an empty string. + """ + possible_names = ["unet", "transformer"] + for name in possible_names: + if hasattr(pipeline, name): + return getattr(pipeline, name), name + return None, "" + + +def set_denoiser_attr(pipeline: Any, denoiser: torch.nn.Module) -> None: + """ + Set the denoiser attribute in a pipeline. + + Parameters + ---------- + pipeline : Any + The pipeline to set the denoiser attribute in. + denoiser : torch.nn.Module + The denoiser to set in the pipeline. + """ + possible_names = ["unet", "transformer"] + for name in possible_names: + if hasattr(pipeline, name): + setattr(pipeline, name, denoiser) + return + raise ValueError(f"Unknown pipeline: {pipeline.__class__.__name__}") + + +def uses_prompt_2(pipeline: Any) -> bool: + """ + Check if the pipeline uses a second prompt. + + Parameters + ---------- + pipeline : Any + The pipeline to check. + + Returns + ------- + bool + True if the pipeline uses a second prompt, False otherwise. + """ + return is_flux_pipeline(pipeline) + + +def get_encode_arguments(pipeline: Any) -> Dict[str, Any]: + """ + Get arguments specific to the encode_prompt function of each pipeline type. + + Parameters + ---------- + pipeline : Any + The pipeline to get the encode_prompt method from. + + Returns + ------- + Dict[str, Any] + Keyword arguments to pass to the encode_prompt function. + """ + if is_sd_pipeline(pipeline) or is_sdxl_pipeline(pipeline): + return dict(do_classifier_free_guidance=True) + elif is_sana_pipeline(pipeline) or is_flux_pipeline(pipeline): + return dict() + else: + raise ValueError(f"Unknown pipeline: {pipeline.__class__.__name__}") + + +def move_secondary_components(pipeline: Any, device: str | torch.device) -> None: + """ + Move the secondary components of the pipeline to the device. + + Parameters + ---------- + pipeline : Any + The pipeline to move the secondary components to. + device : str | torch.device + The device to move the secondary components to. 
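+ + Examples + -------- + Hedged sketch; assumes ``pipe`` is a loaded diffusers pipeline with a text encoder: + + >>> move_secondary_components(pipe, "cuda")  # doctest: +SKIP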
+ """ + if hasattr(pipeline, "text_encoder"): + pipeline.text_encoder.to(device=device) + if hasattr(pipeline, "text_encoder_2"): + pipeline.text_encoder_2.to(device=device) + + +def check_resolution_mismatch(pipeline: Any, dataloader: torch.utils.data.DataLoader) -> None: + """ + Log a warning if there's a mismatch between the dataloader image resolution and the pipeline's configured resolution. + + Parameters + ---------- + pipeline : Any + The pipeline. + dataloader : torch.utils.data.DataLoader + The dataloader containing the training images. + """ + # Get first batch to check image size + first_batch = next(iter(dataloader)) + images = first_batch[1] # (batch_size, channels, height, width) + image_height, image_width = images.shape[-2:] + + if is_sd_pipeline(pipeline) or is_sdxl_pipeline(pipeline): + config_size = pipeline.unet.config.sample_size * pipeline.vae_scale_factor + elif is_sana_pipeline(pipeline): + config_size = pipeline.transformer.config.sample_size * pipeline.vae_scale_factor + elif is_flux_pipeline(pipeline): + config_size = pipeline.vae.config.sample_size + else: + pruna_logger.warning( + f"Unknown pipeline: {pipeline.__class__.__name__}, please make sure the image resolution matches " + "the pipeline's configured resolution for finetuning as it might affect recovery performance." + ) + return + + if image_height != config_size or image_width != config_size: + pruna_logger.warning( + f"The resolution of the provided dataset ({image_height}x{image_width}) differs from " + f"the pipeline's configured resolution ({config_size}x{config_size}). " + "This might affect recovery performance." + ) + + +def get_denoiser_output_class(denoiser: Any) -> type[BaseOutput]: + """ + Get the denoiser output class for a denoiser by inspecting its forward method signature. + + Parameters + ---------- + denoiser : Any + The denoiser whose output class to get. + + Returns + ------- + type[BaseOutput] | None + The denoiser output class if found, otherwise None. + """ + # Get the forward method signature + signature = inspect.signature(denoiser.forward) + output_type = signature.return_annotation + + # Extract different types from Union + output_types = get_args(output_type) if get_origin(output_type) is Union else [output_type] + base_output_types = [t for t in output_types if inspect.isclass(t) and issubclass(t, BaseOutput)] + + if len(base_output_types) == 1: + return base_output_types[0] + else: + raise ValueError( + f"Could not infer the denoiser's return type, expected exactly one BaseOutput type, got {output_type}" + ) diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py new file mode 100644 index 00000000..935468b8 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py @@ -0,0 +1,557 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from pathlib import Path +from typing import Any, List, Literal, Tuple + +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities.seed import isolate_rng + +try: + from bitsandbytes.optim import AdamW8bit # type: ignore[import-untyped] +except ImportError: + AdamW8bit = None + +from ConfigSpace import ( + CategoricalHyperparameter, + Constant, + UniformFloatHyperparameter, + UniformIntegerHyperparameter, +) + +from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner +from pruna.algorithms.global_utils.recovery.finetuners.diffusers import ( + pack_and_predict, + scheduler_interface, + utils, +) +from pruna.algorithms.global_utils.recovery.utils import ( + get_dtype, + get_trainable_parameters, + split_defaults, +) +from pruna.config.hyperparameters import Boolean +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.logging.logger import pruna_logger + + +class TextToImageFinetuner(PrunaFinetuner): + """Finetuner for text-to-image models.""" + + @classmethod + def get_hyperparameters(cls, **override_defaults) -> List: + """ + Configure all method-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + **override_defaults : dict + The hyperparameters to override. + + Returns + ------- + list + The hyperparameters. + """ + defaults = { + "training_batch_size": 0, # 0: the default smash_config.train_dataloader's batch size is used + "gradient_accumulation_steps": 1, + "num_epochs": 1.0, + "validate_every_n_epoch": 1.0, + "learning_rate": 1e-4, + "weight_decay": 1e-2, + "report_to": "none", + "optimizer": "AdamW8bit" if torch.cuda.is_available() else "AdamW", # AdamW8bit from BnB assumes CUDA + } + defaults.update(override_defaults) + string_defaults, numeric_defaults = split_defaults(defaults) + + return [ + UniformIntegerHyperparameter( + "training_batch_size", + lower=0, + upper=4096, + default_value=numeric_defaults["training_batch_size"], + meta=dict(desc="Batch size for finetuning."), + ), + UniformIntegerHyperparameter( + "gradient_accumulation_steps", + lower=1, + upper=1024, + default_value=numeric_defaults["gradient_accumulation_steps"], + meta=dict(desc="Number of gradient accumulation steps for finetuning."), + ), + UniformIntegerHyperparameter( + "num_epochs", + lower=0, + upper=4096, + default_value=numeric_defaults["num_epochs"], + meta=dict(desc="Number of epochs for finetuning."), + ), + UniformFloatHyperparameter( + "validate_every_n_epoch", + lower=0.0, + upper=4096.0, + default_value=numeric_defaults["validate_every_n_epoch"], + meta=dict( + desc="Number of epochs between each round of validation and model checkpointing. " + "If the value is between 0 and 1, validation will be performed multiple times per epoch, " + "e.g. 1/8 will result in 8 validations per epoch." 
+ ), + ), + UniformFloatHyperparameter( + "learning_rate", + lower=0.0, + upper=1.0, + default_value=numeric_defaults["learning_rate"], + meta=dict(desc="Learning rate for finetuning."), + ), + Constant("weight_decay", numeric_defaults["weight_decay"]), + # report_to: for consistency with text-to-text-lora but wandb and tensorboard are not supported yet + Constant("report_to", string_defaults["report_to"]), + Boolean( + "use_cpu_offloading", + default=False, + meta=dict(desc="Whether to use CPU offloading for finetuning."), + ), # necessary for Flux in float16 on L40S GPU (48gb VRAM) + CategoricalHyperparameter( + "optimizer", + choices=["AdamW8bit", "AdamW", "Adam"], + default_value=string_defaults["optimizer"], + meta=dict(desc="Which optimizer to use for finetuning."), + ), + ] + + @classmethod + def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any: + """ + Finetune the model's previously activated parameters on data. + + This function is adapted from the HuggingFace implementation of the finetuning process at + https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py, + https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sana.py, + and https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py + + Parameters + ---------- + pipeline : Any + The pipeline containing components to finetune. + smash_config : SmashConfigPrefixWrapper + The configuration for the finetuner. + seed : int + The seed to use for reproducibility. + recoverer : str + The name of the algorithm used for finetuning, used for logging. + + Returns + ------- + Any + The finetuned pipeline. + """ + dtype = get_dtype(pipeline) + device = smash_config.device if isinstance(smash_config.device, str) else smash_config.device.type + + # split seed into two rng: generator for the dataloader and a seed for the training part + generator = torch.Generator().manual_seed(seed) + fit_seed = int(torch.randint(0, 2**32 - 1, (1,), generator=generator).item()) + + # Dataloaders + # override batch size if user specified one for finetuning specifically + batch_size: int + if smash_config["training_batch_size"] > 0 and smash_config.is_batch_size_locked(): + pruna_logger.warning( + "Batch size is locked by a previous smashing algorithm, " + "ignoring user-specified batch size for finetuning." + ) + batch_size = smash_config.batch_size + elif smash_config["training_batch_size"] > 0: + batch_size = smash_config["training_batch_size"] + else: + batch_size = smash_config.batch_size + # train dataloader has a generator for reproducibility, val is not shuffled so it doesn't need one + train_dataloader = smash_config.data.train_dataloader( + batch_size=batch_size, + output_format="normalized", + generator=generator, + ) + val_dataloader = smash_config.data.val_dataloader( + batch_size=batch_size, + output_format="normalized", + ) + + optimizer_name = smash_config["optimizer"] + if optimizer_name == "AdamW8bit" and device != "cuda": + pruna_logger.warning( + "Optimizer AdamW8bit from bitsandbytes requires CUDA, continuing with AdamW from torch." 
+ ) + optimizer_name = "AdamW" + + # Check resolution mismatch + utils.check_resolution_mismatch(pipeline, train_dataloader) + + # Finetune the model + trainable_denoiser = DenoiserTL( + pipeline, + optimizer_name, + smash_config["learning_rate"], + smash_config["weight_decay"], + recoverer, + smash_config["use_cpu_offloading"], + ) + # make directory for checkpoints + model_path = Path(smash_config.cache_dir) / "recovery" + model_path.mkdir(exist_ok=True, parents=True) + + early_stopping = EarlyStopping(monitor="validation_loss", patience=3, mode="min", check_finite=True) + checkpoint_callback = ModelCheckpoint( + monitor="validation_loss", + save_top_k=1, + mode="min", + every_n_epochs=1, + dirpath=model_path, + filename="{recoverer}-{epoch:02d}-{validation_loss:.4f}", + ) + callbacks = [early_stopping, checkpoint_callback] + + if smash_config["validate_every_n_epoch"] >= 1.0: + # set TL trainer to perform validation every n epochs, for small datasets + check_val_every_n_epoch = int(smash_config["validate_every_n_epoch"]) + val_check_interval = None + else: + # set TL trainer to perform validation multiple times per epoch, for large datasets + check_val_every_n_epoch = 1 + val_check_interval = smash_config["validate_every_n_epoch"] + + precision: Literal["16-true", "bf16-true", "32"] + if dtype == torch.float16: + precision = "16-true" + elif dtype == torch.bfloat16: + precision = "bf16-true" + else: + precision = "32" + + trainer = pl.Trainer( + accelerator=device, + callbacks=callbacks, + max_epochs=smash_config["num_epochs"], + gradient_clip_val=1.0, + precision=precision, + accumulate_grad_batches=smash_config["gradient_accumulation_steps"], + check_val_every_n_epoch=check_val_every_n_epoch, + val_check_interval=val_check_interval, + logger=False, + ) + with isolate_rng(): + pl.seed_everything(fit_seed, workers=True, verbose=False) + trainer.validate(trainable_denoiser, dataloaders=val_dataloader, verbose=False) + trainer.fit(trainable_denoiser, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) + + # Loading the best checkpoint is slow and currently creates conflicts with some quantization algorithms, + # e.g. diffusers_int8. Skipping calling DenoiserTL.load_from_checkpoint for now. + return pipeline.to(device) + + +class DenoiserTL(pl.LightningModule): + """ + Pipeline in LightningModule format for finetuning the denoiser. + + Parameters + ---------- + pipeline : Any + The pipeline to finetune. + optimizer_name : str + The name of the optimizer to use, options are "AdamW8bit", or optimizers from torch.optim. + learning_rate : float + The learning rate to use for finetuning. + weight_decay : float + The weight decay to use for finetuning. + recoverer : str + The name of the algorithm used for finetuning, used for logging. + use_cpu_offloading : bool, optional + Whether to use CPU offloading for finetuning. + validation_seed : int, optional + The seed to use for validation, used to reproducibly generate the random noise and timesteps for the + validation set. 
+ """ + + def __init__( + self, + pipeline: Any, + optimizer_name: str, + learning_rate: float, + weight_decay: float, + recoverer: str, + use_cpu_offloading: bool = False, + validation_seed: int = 42, + ): + super().__init__() + self.pipeline = pipeline + self.optimizer_name = optimizer_name + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.use_cpu_offloading = use_cpu_offloading + self.validation_generator = torch.Generator().manual_seed(validation_seed) + self.validation_seeds: List[int] = [] + # register the denoiser explicitly to work with self.parameters() and other LightningModule methods + denoiser, _ = utils.get_denoiser_attr(pipeline) + if denoiser is None or not isinstance(denoiser, torch.nn.Module): + raise ValueError("Could not find the denoiser in the pipeline.") + self.denoiser = denoiser + + self.pack_and_predict = pack_and_predict.get_pack_and_predict_fn(pipeline) + self.uses_prompt_2 = utils.uses_prompt_2(pipeline) + self.encode_arguments = utils.get_encode_arguments(pipeline) + self.training_scheduler = scheduler_interface.get_training_scheduler(pipeline.scheduler) + + # basic logging + self.recoverer = recoverer + self.val_losses: List[torch.Tensor] = [] + self.train_losses: List[torch.Tensor] = [] + + def forward( + self, noisy_latents: torch.Tensor, encoder_hidden_states: torch.Tensor, timesteps: torch.Tensor + ) -> torch.Tensor: + """ + Forward pass of the denoiser. + + Parameters + ---------- + noisy_latents : torch.Tensor + The noisy latents to denoise. + encoder_hidden_states : torch.Tensor + The encoder hidden states. + timesteps : torch.Tensor + The timesteps used for positional encoding. + + Returns + ------- + torch.Tensor + The denoised latents. + """ + return self.pack_and_predict(self.pipeline, noisy_latents, encoder_hidden_states, timesteps) + + def prepare_latents_targets( + self, images: torch.Tensor, timesteps: torch.Tensor, generator: torch.Generator | None = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare the latents and targets for the denoiser. + + Parameters + ---------- + images : torch.Tensor + The images to prepare. + timesteps : torch.Tensor + The timesteps to use for the denoiser. + generator : torch.Generator | None, optional + The generator to use for drawing random noise. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + The noisy latents and the targets. 
+ """ + dtype = self.pipeline.vae.dtype + + # Convert images to latent space + latents = self.pipeline.vae.encode(images.to(dtype)) + # applying vae is fine with Sana in 32 bit and 16 bit (model in 16bit w/o, both with no info for the trainer and + # with 16-true), but produces NaNs in 16-mixed + latents = latents.latent_dist.sample() if hasattr(latents, "latent_dist") else latents.latent + # Apply shift and scaling factors + if hasattr(self.pipeline.vae.config, "shift_factor") and self.pipeline.vae.config.shift_factor is not None: + latents = latents - self.pipeline.vae.config.shift_factor + latents = latents * self.pipeline.vae.config.scaling_factor + + # Add noise to the latents according to the noise magnitude at each timestep + # torch.randn_like does not support generators + noise = torch.randn(latents.shape, dtype=latents.dtype, device=latents.device, generator=generator) + noisy_latents = scheduler_interface.add_noise(self.training_scheduler, latents, noise, timesteps) + + # define the corresponding target for training + target = scheduler_interface.get_target(self.training_scheduler, latents, noise, timesteps) + + return noisy_latents, target + + def encode_prompt(self, captions: List[str]) -> torch.Tensor: + """ + Encode the prompts. + + Parameters + ---------- + captions : list[str] + The captions to encode. + + Returns + ------- + torch.Tensor + The encoded prompts. + """ + prompt_args = [captions] + if self.uses_prompt_2: + prompt_args = prompt_args * 2 + encoder_hidden_states = self.pipeline.encode_prompt( + *prompt_args, + device=self.device, + num_images_per_prompt=1, + **self.encode_arguments, + ) + return encoder_hidden_states + + def training_step(self, batch: Tuple[List[str], torch.Tensor], batch_idx: int): + """ + Training step of the denoiser. + + Parameters + ---------- + batch : tuple[list[str], torch.Tensor] + The batch of (captions, images). + batch_idx : int + The index of the batch. + + Returns + ------- + torch.Tensor + The MSE loss between the predicted and target latents. + """ + captions, images = batch + batch_size = int(images.shape[0]) + + if self.use_cpu_offloading: + # make sure secondary components are on the same device as the model + utils.move_secondary_components(self.pipeline, self.device) + + # uniform timesteps for simplicity, replace with compute_density_for_timestep_sampling in future + timesteps = scheduler_interface.sample_timesteps(self.training_scheduler, batch_size, self.device) + noisy_latents, target = self.prepare_latents_targets(images, timesteps) + encoder_hidden_states = self.encode_prompt(captions) + + if self.use_cpu_offloading: + # clear memory: required in testing to fit flux bfloat16 + lora finetuning with bs=1 onto 48gb VRAM + utils.move_secondary_components(self.pipeline, "cpu") + + model_pred = self.forward(noisy_latents, encoder_hidden_states, timesteps) + loss = self.loss(model_pred, target) + + self.log("train_loss", loss) + self.train_losses.append(loss) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + """ + Validation step of the denoiser. + + Parameters + ---------- + batch : tuple[list[str], torch.Tensor] + The batch of (captions, images). + batch_idx : int + The index of the batch. + + Returns + ------- + torch.Tensor + The MSE loss between the predicted and target latents. 
+ """ + captions, images = batch + batch_size = int(images.shape[0]) + + if self.use_cpu_offloading: + # make sure secondary components are on the same device as the model + utils.move_secondary_components(self.pipeline, self.device) + + # get seeds for reproducibility + # - make sure the timesteps and random noise are the same across validations + # - make sure the seeds are different across batches + if len(self.validation_seeds) <= batch_idx: + seed = int(torch.randint(0, 2**32 - 1, (1,), generator=self.validation_generator).item()) + self.validation_seeds.append(seed) + else: + seed = self.validation_seeds[batch_idx] + + with isolate_rng(): + pl.seed_everything(seed, verbose=False) + + timesteps = scheduler_interface.sample_timesteps(self.training_scheduler, batch_size, self.device) + noisy_latents, target = self.prepare_latents_targets(images, timesteps) + encoder_hidden_states = self.encode_prompt(captions) + + # no need to do CPU offloading in evaluation since gradients are not computed + model_pred = self.forward(noisy_latents, encoder_hidden_states, timesteps) + loss = self.loss(model_pred, target) + + self.log("validation_loss", loss) + self.val_losses.append(loss) + return {"loss": loss} + + def loss(self, model_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute the denoising loss. + + Parameters + ---------- + model_pred : torch.Tensor + The predicted latents. + target : torch.Tensor + The target latents. + + Returns + ------- + torch.Tensor + The MSE loss between the predicted and target latents. + """ + return torch.nn.functional.mse_loss(model_pred, target) + + def on_train_epoch_end(self): + """Log the train loss.""" + loss = torch.stack(self.train_losses).mean() + pruna_logger.info(f"{self.recoverer} - epoch {self.current_epoch} - train loss: {loss:.3e}") + self.train_losses.clear() + + def on_validation_epoch_end(self): + """Log the validation loss.""" + if self.trainer.sanity_checking: + return + loss = torch.stack(self.val_losses).mean() + pruna_logger.info(f"{self.recoverer} - epoch {self.current_epoch} - validation loss: {loss:.3e}") + self.val_losses.clear() + + def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Configure the optimizer. + + Returns + ------- + torch.optim.Optimizer + The optimizer. + """ + lr = self.learning_rate + wd = self.weight_decay + kwargs = {"eps": 1e-7} if self.trainer.precision in [16, "16", "16-true"] else {} + + if self.optimizer_name == "AdamW8bit": + optimizer_cls = AdamW8bit + if optimizer_cls is None: + pruna_logger.warning( + "Recovery with AdamW8bit requires bitsandbytes to be installed, continuing with AdamW from torch." + ) + optimizer_cls = torch.optim.AdamW + else: + optimizer_cls = getattr(torch.optim, self.optimizer_name) + finetune_params = get_trainable_parameters(self.pipeline) + + return optimizer_cls(finetune_params, lr=lr, weight_decay=wd, **kwargs) diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_text_finetuner.py b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_text_finetuner.py new file mode 100644 index 00000000..1b126c78 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_text_finetuner.py @@ -0,0 +1,256 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import tempfile +from typing import Iterator, Optional + +import torch +from ConfigSpace import ( + CategoricalHyperparameter, + Constant, + UniformFloatHyperparameter, + UniformIntegerHyperparameter, +) +from datasets import Dataset +from trl import SFTConfig, SFTTrainer + +from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner +from pruna.algorithms.global_utils.recovery.utils import get_dtype, split_defaults +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.logging.logger import pruna_logger + + +class TextToTextFinetuner(PrunaFinetuner): + """Finetuner for text-to-text models.""" + + @classmethod + def get_hyperparameters(cls, **override_defaults) -> list: + """ + Configure all method-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + **override_defaults : dict + The hyperparameters to override. + + Returns + ------- + list + The hyperparameters. + """ + defaults: dict[str, int | float | str] = { + "training_batch_size": 1, + "gradient_accumulation_steps": 1, + "num_epochs": 1.0, + "learning_rate": 2e-4, + "dataset_text_field": "text", + "report_to": "none", + "optimizer": "AdamW8bit", + } + defaults.update(override_defaults) + string_defaults, numeric_defaults = split_defaults(defaults) + return [ + UniformIntegerHyperparameter( + "training_batch_size", + lower=1, + upper=4096, + default_value=numeric_defaults["training_batch_size"], + meta=dict(desc="Batch size for finetuning."), + ), + UniformIntegerHyperparameter( + "gradient_accumulation_steps", + lower=1, + upper=1024, + default_value=numeric_defaults["gradient_accumulation_steps"], + meta=dict(desc="Number of gradient accumulation steps for finetuning."), + ), + UniformFloatHyperparameter( + "num_epochs", + lower=0.0, + upper=4096.0, + default_value=numeric_defaults["num_epochs"], + meta=dict(desc="Number of epochs for finetuning."), + ), + UniformFloatHyperparameter( + "learning_rate", + lower=0.0, + upper=1.0, + default_value=numeric_defaults["learning_rate"], + meta=dict(desc="Learning rate for finetuning."), + ), + Constant("dataset_text_field", string_defaults["dataset_text_field"]), + CategoricalHyperparameter( + "report_to", + choices=["none", "wandb", "tensorboard"], + default_value=string_defaults["report_to"], + meta=dict(desc="Where to report the finetuning results."), + ), + CategoricalHyperparameter( + "optimizer", + choices=["AdamW", "AdamW8bit", "PagedAdamW8bit"], + default_value=string_defaults["optimizer"], + meta=dict(desc="Which optimizer to use for finetuning."), + ), + ] + + @classmethod + def finetune( + cls, + model: torch.nn.Module, + smash_config: SmashConfigPrefixWrapper, + seed: int, + recoverer: str, + report_every_n_samples: int | None = None, + ) -> torch.nn.Module: + """ + Finetune the model's previously activated parameters. + + Parameters + ---------- + model : torch.nn.Module + The model to apply the finetuner to. + smash_config : SmashConfigPrefixWrapper + The configuration for the finetuner. + seed : int + The seed to use for reproducibility. 
+ recoverer : str + The recoverer used, i.e. the selection of parameters to finetune. This is only used for logging purposes. + report_every_n_samples : int | None, optional + The number of samples between reports to the logger. + If None, the number of samples between reports is set to 1/8 of the dataset size. + + Returns + ------- + torch.nn.Module + The finetuned model. + """ + dtype = get_dtype(model) + + # format dataset + dataset, dataset_text_field = cls._format_dataset_for_causal_lm( + # dataloader can't be None because of the dataset_required flag + smash_config.train_dataloader().dataset, # type: ignore[union-attr] + smash_config["dataset_text_field"], + ) + + # setup optimizer + if smash_config["optimizer"] == "AdamW8bit": + optim = "adamw_bnb_8bit" + elif smash_config["optimizer"] == "PagedAdamW8bit": + optim = "paged_adamw_8bit" + else: + optim = "adamw_torch" + + # setup training + model.train() + if report_every_n_samples is None: + report_every_n_samples = max(1, len(dataset) // 8) + with tempfile.TemporaryDirectory(prefix=str(smash_config["cache_dir"])) as temp_dir: + training_args = SFTConfig( + # task + dataset_text_field=dataset_text_field, + optim_target_modules=["lora"], + # batch size + per_device_train_batch_size=smash_config["training_batch_size"], + gradient_accumulation_steps=smash_config["gradient_accumulation_steps"], + # optimization + warmup_steps=100, + num_train_epochs=smash_config["num_epochs"], + learning_rate=smash_config["learning_rate"], + lr_scheduler_type="cosine", + weight_decay=0.01, + fp16=(dtype == torch.float16), + bf16=(dtype == torch.bfloat16), + optim=optim, + # logging + logging_strategy="steps", + logging_steps=report_every_n_samples, + disable_tqdm=False, + report_to=smash_config["report_to"], + # saving + run_name=f"Recovery-{recoverer}", + output_dir=temp_dir, + seed=seed, + ) + trainer = SFTTrainer( + model=model, + train_dataset=dataset, + processing_class=smash_config.tokenizer, + args=training_args, + ) + trainer.train() + + # Get the unwrapped model + model = trainer.accelerator.unwrap_model(trainer.model) + model.eval() + model = model.to(dtype=dtype) + + return model + + @staticmethod + def _format_dataset_for_causal_lm( + dataset: Dataset | torch.utils.data.Dataset, text_field: str + ) -> tuple[Dataset, Optional[str]]: + """ + Format a dataset for SFTTrainer. + + Parameters + ---------- + dataset : Dataset | torch.utils.data.Dataset + The dataset to format. + text_field : str + The text field to use. + + Returns + ------- + tuple[Dataset, Optional[str]] + The formatted dataset and dataset text field. + A text field is only provided if the dataset is a huggingface dataset not yet tokenized. 
+ """ + column_names = dataset.column_names # type: ignore[union-attr] + if hasattr(dataset, "column_names") and "input_ids" in column_names: + # processed dataset with no need for tokenization + # remove all other columns, otherwise HF's Trainer tends to infer the wrong task from those columns + removed_columns = [col for col in column_names if col not in [text_field, "input_ids"]] + dataset = dataset.remove_columns(removed_columns) # type: ignore[union-attr] + return dataset, None + elif hasattr(dataset, "column_names") and text_field in column_names: + # raw dataset with text field + # remove all other columns, otherwise HF's Trainer tends to infer the wrong task from those columns + removed_columns = [col for col in column_names if col != text_field] + dataset = dataset.remove_columns(removed_columns) # type: ignore[union-attr] + return dataset, text_field + elif len(dataset[0]) == 2 and torch.all(dataset[0][0][1:] == dataset[0][1][:-1]): + # (input, label) format for next token prediction, input and label are the same with a single token shift + # attempt to convert dataset to a huggingface dataset + def data_generator() -> Iterator[dict[str, torch.Tensor]]: + for idx in range(len(dataset)): # type: ignore[arg-type] + data_input, label = dataset[idx] + # append last token of label to input for next token prediction + input_ids = torch.cat((data_input, label[..., -1:])) # this conversion slows finetuning a little + attention_mask = torch.ones_like(input_ids) + yield {"input_ids": input_ids, "attention_mask": attention_mask} + + dataset = Dataset.from_generator(data_generator) + return dataset, None + else: + pruna_logger.error( + "The dataset provided for recovery is not compatible. Accepted format include:\n" + " - huggingface datasets with a text field,\n" + " - huggingface datasets with an input_ids field,\n" + " - (input, label) format for next token prediction." + ) + raise ValueError(f"Expected a dataset with text or input_ids fields for LoRA recovery but got: {dataset}") diff --git a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py new file mode 100644 index 00000000..7ae655de --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py @@ -0,0 +1,321 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
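+
+"""PERP recovery: finetune a small, targeted subset of parameters after compression.
+
+Minimal sketch of the pattern this module orchestrates, in plain torch and for
+illustration only (the real adapters below additionally handle quantized layers,
+LoRA injection and parameter counting): freeze everything, then re-enable only the
+norm and bias parameters before handing the model to a finetuner.
+
+    import torch
+
+    def activate_norm_and_bias(model: torch.nn.Module) -> int:
+        # freeze every parameter first, then opt a small subset back in
+        for param in model.parameters():
+            param.requires_grad_(False)
+        activated = 0
+        for module in model.modules():
+            if isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
+                # norm scales and shifts
+                for param in module.parameters():
+                    param.requires_grad_(True)
+                    activated += param.numel()
+            elif getattr(module, "bias", None) is not None:
+                # biases of linear / convolutional layers
+                module.bias.requires_grad_(True)
+                activated += module.bias.numel()
+        return activated
+"""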
+
+from typing import Any, Dict
+
+import torch
+from ConfigSpace import Constant
+
+from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
+from pruna.algorithms.base.tags import AlgorithmTag
+from pruna.algorithms.global_utils.recovery.adapters.utils import freeze_parameters
+from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner
+from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr
+from pruna.algorithms.global_utils.recovery.utils import get_trainable_parameters
+from pruna.config.smash_config import SmashConfigPrefixWrapper
+from pruna.engine.model_checks import (
+    is_causal_lm,
+    is_flux_pipeline,
+    is_sana_pipeline,
+    is_sd_pipeline,
+    is_sdxl_pipeline,
+)
+from pruna.engine.save import SAVE_FUNCTIONS
+from pruna.logging.logger import pruna_logger
+
+
+class PERPRecoverer(PrunaAlgorithmBase):
+    """
+    General purpose PERP recoverer using norm, head and bias finetuning and optionally HuggingFace's LoRA.
+
+    Parameters
+    ----------
+    task_name : str
+        The name of the task to recover.
+    use_lora : bool
+        Whether to use LoRA adapters, which will not be merged and therefore slow down inference.
+    use_in_place : bool
+        Whether to use norm, bias and head finetuning, which will modify the model in place.
+    is_distillation : bool
+        Whether to use distillation, which requires a distillation datamodule; otherwise finetuning is used.
+    """
+
+    group_tags: list[AlgorithmTag] = [AlgorithmTag.RECOVERER]  # type: ignore[attr-defined]
+    save_fn = SAVE_FUNCTIONS.pickled
+    references: dict[str, str] = {
+        "GitHub": "https://github.com/huggingface/peft",
+        "Paper": "https://arxiv.org/pdf/2312.15230",
+    }
+    processor_required: bool = False
+    dataset_required: bool = True
+    runs_on: list[str] = ["cpu", "cuda"]
+
+    def __init__(self, task_name: str, use_lora: bool, use_in_place: bool, is_distillation: bool) -> None:
+        self.task_name = task_name
+        self.tokenizer_required = task_name == "text_to_text"  # type: ignore[misc]
+
+        if not use_lora and not use_in_place:
+            raise ValueError("Arguments use_lora and use_in_place cannot both be False, please use one of the two.")
+        self.use_lora = use_lora
+        self.use_in_place = use_in_place
+        self.is_distillation = is_distillation
+        # define all used types of adapters
+        self.adapters = []
+        if self.use_in_place:
+            self.adapters.append("NormAdapter")
+            self.adapters.append("BiasAdapter")
+            if self.task_name == "text_to_text":
+                self.adapters.append("HeadAdapter")
+        if self.use_lora:
+            self.adapters.append("LoraAdapter")
+
+        # The recoverer receives a single seed and turns it into a seed generator that seeds the adapter
+        # initialization and the actual finetuning/distillation. We don't know at which point in the application of
+        # the algorithm the adapters are created (during apply or in the pre-smash hook), so we store a single
+        # generator here, which gets initialized in apply or in the pre-smash hook (whichever is called first), and
+        # use this generator for seeding at any point during the application of this algorithm.
+        self.seed_generator: torch.Generator | None = None
+
+        super().__init__()  # self.adapters need to be set before calling get_hyperparameters
+
+    def get_hyperparameters(self) -> list:
+        """
+        Configure all algorithm-specific hyperparameters with ConfigSpace.
+
+        Returns
+        -------
+        list
+            The hyperparameters.
+        """
+        imported_modules = self.import_algorithm_packages()
+        adapters = [imported_modules[adapter] for adapter in self.adapters]
+
+        hyperparameters = []
+
+        # collect adapters hyperparameters and add the adapter's prefix
+        for adapter in adapters:
+            adapter_hyperparams = adapter.get_hyperparameters(self.task_name)
+            for param in adapter_hyperparams:
+                param.name = f"{adapter.adapter_prefix}_{param.name}"
+            hyperparameters.extend(adapter_hyperparams)
+
+        # collect finetuner hyperparameters
+        finetuner_hyperparams = imported_modules["Finetuner"].get_hyperparameters()
+        hyperparameters.extend(finetuner_hyperparams)
+
+        seed = int(torch.randint(0, 2**32 - 1, (1,)).item())
+        hyperparameters.append(
+            Constant(  # set to constant, waiting for user-defined non-optimized hyperparameters
+                "seed",
+                seed,
+                meta=dict(desc="Random seed used for reproducibility."),
+            )
+        )
+
+        return hyperparameters
+
+    def model_check_fn(self, model: Any) -> bool:
+        """
+        Check if the model is compatible with PERP.
+
+        Parameters
+        ----------
+        model : Any
+            The model to check.
+
+        Returns
+        -------
+        bool
+            True if the model matches the configured task, i.e. a Stable Diffusion, Stable Diffusion XL, Sana or
+            Flux pipeline for text-to-image, or a causal language model for text-to-text; False otherwise.
+        """
+        if self.task_name == "text_to_image":
+            return is_sd_pipeline(model) or is_sdxl_pipeline(model) or is_sana_pipeline(model) or is_flux_pipeline(model)
+        elif self.task_name == "text_to_text":
+            return is_causal_lm(model)
+        else:
+            raise ValueError(f"Task name {self.task_name} is not supported for PERP recovery.")
+
+    def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> None:
+        """
+        Perform any necessary setup steps before the smashing process begins.
+
+        Parameters
+        ----------
+        model : Any
+            The model to prepare for smashing.
+        smash_config : SmashConfigPrefixWrapper
+            Configuration object containing algorithm settings.
+        """
+        # Identify which components in the pipeline might need to be set up before smashing
+        if self.task_name == "text_to_image":
+            model_recovery, denoiser_attr_name = get_denoiser_attr(model)
+            if model_recovery is None:
+                pruna_logger.error("Could not infer the denoiser attribute in the pipeline, skipping recovery.")
+                return
+        else:  # text_to_text
+            model_recovery = model
+
+        # initialize the seed generator if it is not already done, see comment in __init__
+        if self.seed_generator is None:
+            self.seed_generator = torch.Generator().manual_seed(smash_config["seed"])
+        # prepare individual seeds for adapters
+        adapter_seeds = [
+            int(torch.randint(0, 2**32 - 1, (1,), generator=self.seed_generator).item()) for _ in self.adapters
+        ]
+
+        imported_modules = self.import_algorithm_packages()
+        adapters = [imported_modules[adapter] for adapter in self.adapters]
+
+        for adapter, adapter_seed in zip(adapters, adapter_seeds):
+            adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_")
+            adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed)
+
+    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
+        """
+        Recover the performance of a given model with a given config.
+
+        Parameters
+        ----------
+        model : Any
+            The model to recover from.
+        smash_config : SmashConfigPrefixWrapper
+            The configuration for the recovery.
+
+        Returns
+        -------
+        Any
+            The recovered model.
+ """ + # Identify which components in the pipeline will have adapters + if self.task_name == "text_to_image": + model_recovery, denoiser_attr_name = get_denoiser_attr(model) + if model_recovery is None: + pruna_logger.error("Could not infer the denoiser attribute in the pipeline, skipping recovery.") + return model + else: # text_to_text + model_recovery = model + + # freeze all parameters so we can add trainable parameters / select a subset for finetuning + freeze_parameters(model) + model_recovery.train() # activate running batch norm / dropout / etc.. as preparation for finetuning + + # store device and dtype and warn user against finetuning on CPU if necessary + device = smash_config.device + if device == "cpu" or (hasattr(device, "type") and device.type == "cpu"): + warning_if_cpu = "Model is on CPU, this is not recommended for recovery as it may take a long time." + pruna_logger.warning(warning_if_cpu) + + # initialize the seed generator if it is not already done, see comment in __init__ + if self.seed_generator is None: + self.seed_generator = torch.Generator().manual_seed(smash_config["seed"]) + # # prepare individual seeds for adapters and distllation (e.g. seeds the random initialization of LoRA) + distillation_seed = int(torch.randint(0, 2**32 - 1, (1,), generator=self.seed_generator).item()) + adapter_seeds = [ + int(torch.randint(0, 2**32 - 1, (1,), generator=self.seed_generator).item()) for _ in self.adapters + ] + + # activate adapters + imported_modules = self.import_algorithm_packages() + adapters = [imported_modules[adapter] for adapter in self.adapters] + + prefixes_used = [] + for adapter, adapter_seed in zip(adapters, adapter_seeds): + adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_") + model_recovery, num_activ_param, num_skip_param = adapter.activate( + model_recovery, adapter_smash_config, seed=adapter_seed + ) + + # log skipped parameters and record which adapters were actually used + if num_skip_param > 0: + pruna_logger.warning( + f"Skipped {num_skip_param:.2e} {adapter.adapter_prefix} parameters " + "that were not trainable due to quantization." + ) + elif num_activ_param == 0: # num_skip_param = 0 too so there is no such parameter + pruna_logger.info(f"No trainable {adapter.adapter_prefix} parameters found: skipping adapter.") + else: + prefixes_used.append(adapter.adapter_prefix) + model_recovery.to(device=device) + + # check if any parameters were activated + num_trainable_params = sum(p.numel() for p in get_trainable_parameters(model)) + if num_trainable_params == 0: + pruna_logger.error("No trainable parameters were activated, skipping recovery.") + return model + else: + pruna_logger.info( + f"Recovering with PERP: {' + '.join(prefixes_used)}, totaling {num_trainable_params:.2e} parameters." + ) + + # replace the component in the pipeline + if self.task_name == "text_to_image": + setattr(model, denoiser_attr_name, model_recovery) + else: + model = model_recovery + + # finetune the model + model = imported_modules["Finetuner"].finetune(model, smash_config, distillation_seed, self.algorithm_name) + + # switch back to eval mode + model_recovery.eval() # disable dropout, set batch norm to eval mode, etc.. 
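+        # Hand the model back in inference mode: gradients are no longer needed, and later steps in the
+        # smashing pipeline (e.g. torch_compile) expect a static, frozen module.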
+ freeze_parameters(model_recovery) # freeze all finetuned parameters for inference + + # remove peft wrapper to recover a model with the same type as the recoverer's input + if self.use_lora and self.task_name == "text_to_image": + base_denoiser = getattr(model, denoiser_attr_name).base_model.model + setattr(model, denoiser_attr_name, base_denoiser) + elif self.use_lora: + model = model.base_model.model + + return model + + def import_algorithm_packages(self) -> Dict[str, Any]: + """ + Provide a algorithm packages for the algorithm. + + Returns + ------- + Dict[str, Any] + The algorithm packages. + """ + from pruna.algorithms.global_utils.recovery.adapters.bias import BiasAdapter + from pruna.algorithms.global_utils.recovery.adapters.head import HeadAdapter + from pruna.algorithms.global_utils.recovery.adapters.lora import LoraAdapter + from pruna.algorithms.global_utils.recovery.adapters.norm import NormAdapter + + Finetuner: type[PrunaFinetuner] | None = None # noqa: N806 + if self.task_name == "text_to_image" and self.is_distillation: + from pruna.algorithms.global_utils.recovery.finetuners.text_to_image_distiller import ( + TextToImageDistiller as Finetuner, + ) + elif self.task_name == "text_to_image": + from pruna.algorithms.global_utils.recovery.finetuners.text_to_image_finetuner import ( + TextToImageFinetuner as Finetuner, + ) + elif self.task_name == "text_to_text" and self.is_distillation: + raise NotImplementedError("Distillation for text-to-text models is not implemented yet.") + elif self.task_name == "text_to_text": + from pruna.algorithms.global_utils.recovery.finetuners.text_to_text_finetuner import ( + TextToTextFinetuner as Finetuner, + ) + else: + raise ValueError(f"Task name {self.task_name} is not supported for PERP recovery.") + + return dict( + BiasAdapter=BiasAdapter, + HeadAdapter=HeadAdapter, + LoraAdapter=LoraAdapter, + NormAdapter=NormAdapter, + Finetuner=Finetuner, + ) diff --git a/src/pruna/algorithms/global_utils/recovery/utils.py b/src/pruna/algorithms/global_utils/recovery/utils.py new file mode 100644 index 00000000..9db70ce9 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/utils.py @@ -0,0 +1,147 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from typing import Any, Callable, List, Tuple + +import torch + + +def split_defaults(defaults: dict[str, Any]) -> tuple[dict[str, str], dict[str, int | float]]: + """ + Split the defaults into string and numeric defaults. + + Parameters + ---------- + defaults : dict[str, Any] + The defaults to split. + + Returns + ------- + tuple[dict[str, str], dict[str, int | float]] + The string and numeric defaults. 
+ """ + string_defaults: dict[str, str] = {k: str(v) for k, v in defaults.items() if isinstance(v, str)} + numeric_defaults: dict[str, int | float] = {k: v for k, v in defaults.items() if k not in string_defaults} + return string_defaults, numeric_defaults + + +def cast_parameters(parameters: List[torch.nn.Parameter] | torch.nn.ParameterList, dtype: torch.dtype | str) -> None: + """ + Cast the parameters of the model to the given dtype. + + Parameters + ---------- + parameters : List[torch.nn.Parameter] | torch.nn.ParameterList + The parameters to cast. + dtype : torch.dtype + The dtype to cast the parameters to. + """ + for param in parameters: + param.data = param.data.to(dtype) + if hasattr(param, "grad") and param.grad is not None: + param.grad = param.grad.to(dtype) + + +def get_dtype(model: Any) -> torch.dtype: + """ + Get the dtype of the model. + + Parameters + ---------- + model : Any + The model to get the dtype from. + + Returns + ------- + torch.dtype + The dtype of the model. + """ + if hasattr(model, "dtype"): + dtype = model.dtype + return dtype + else: # fallback by looking for a float type parameter + for param in model.parameters(): + if "float" in str(param.dtype): + dtype = param.dtype + return dtype + # last resort: use the first parameter's type + return next(iter(model.parameters())).dtype + + +def get_trainable_parameters(model: Any) -> List[torch.nn.Parameter]: + """ + Get the trainable parameters of the model or pipeline. + + Parameters + ---------- + model : Any + The model or pipeline to get the trainable parameters from. + + Returns + ------- + List[torch.nn.Parameter] + The trainable parameters of the model or pipeline. + """ + if isinstance(model, torch.nn.Module): + return [param for param in model.parameters() if param.requires_grad] + + modules = [component for _, component in inspect.getmembers(model) if isinstance(component, torch.nn.Module)] + return [param for module in modules for param in module.parameters() if param.requires_grad] + + +def str_to_int(s: str) -> int: + """ + Deterministically convert a string to an integer. + + Parameters + ---------- + s : str + The string to convert to an integer. + + Returns + ------- + int + An integer obtained from the string. + """ + return int(s.encode("utf-8").hex(), 16) + + +def filter_kwargs(function: Callable, kwargs: dict[str, Any]) -> Tuple[dict[str, Any], dict[str, Any]]: + """ + Filter the kwargs of a function to separate the arguments that are valid for the function and those that are not. + + Parameters + ---------- + function : Callable + The function to filter the kwargs of. + kwargs : dict[str, Any] + The kwargs to filter. + + Returns + ------- + tuple[dict[str, Any], dict[str, Any]] + The valid and invalid kwargs. + """ + valid_kwargs = {} + invalid_kwargs = {} + signature = inspect.signature(function) + for key, value in kwargs.items(): + if key in signature.parameters: + valid_kwargs[key] = value + else: + invalid_kwargs[key] = value + return valid_kwargs, invalid_kwargs diff --git a/src/pruna/algorithms/perp.py b/src/pruna/algorithms/perp.py new file mode 100644 index 00000000..eb6a689a --- /dev/null +++ b/src/pruna/algorithms/perp.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import Iterable + +from pruna.algorithms.base.tags import AlgorithmTag + +from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer + + +class TextToImagePERP(PERPRecoverer): + """ + PERP recoverer for text-to-image models. 
+ + This recoverer is a general purpose PERP recoverer for text-to-image models using norm and bias finetuning + as well as LoRA layers. + + Parameters + ---------- + use_lora : bool + Whether to use LoRA adapters. + use_in_place : bool + Whether to use norm and bias finetuning which will modify the model in place. + """ + + algorithm_name: str = "text_to_image_perp" + tokenizer_required: bool = False + compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache", "flux_caching"] + compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile", "x_fast"] + runs_on: list[str] = ["cuda"] + + def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None: + super().__init__(task_name="text_to_image", use_lora=use_lora, use_in_place=use_in_place, is_distillation=False) + + +class TextToImageInPlacePERP(TextToImagePERP): + """ + PERP recoverer for text-to-image models without LoRA adapters. + + This is the same as ``text_to_image_perp``, but without LoRA layers which add extra computations and thus slow down + the inference of the final model. + """ + + algorithm_name = "text_to_image_inplace_perp" + + def __init__(self) -> None: + super().__init__(use_lora=False) + + +class TextToImageLoRA(TextToImagePERP): + """ + LoRA recoverer for text-to-image models. + + This recoverer attaches LoRA adapters to the model and finetunes them. + """ + + algorithm_name: str = "text_to_image_lora" + + def __init__(self) -> None: + super().__init__(use_in_place=False) + + +class TextToTextPERP(PERPRecoverer): + """ + PERP recoverer for text-to-text models. + + This recoverer is a general purpose PERP recoverer for text-to-text models using norm, head and bias finetuning + as well as LoRA layers. + + Parameters + ---------- + use_lora : bool + Whether to use LoRA adapters. + use_in_place : bool + Whether to use norm, bias and head finetuning which will modify the model in place. + """ + + algorithm_name: str = "text_to_text_perp" + tokenizer_required: bool = True + compatible_before: Iterable[str | AlgorithmTag] = ["half", "quanto", "torch_dynamic"] + compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile", "x_fast"] + + def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None: + super().__init__(task_name="text_to_text", use_lora=use_lora, use_in_place=use_in_place, is_distillation=False) + + +class TextToTextInPlacePERP(TextToTextPERP): + """ + PERP recoverer for text-to-text models without LoRA adapters. + + This is the same as ``text_to_text_perp``, but without LoRA layers which add extra computations and thus slow down + the inference of the final model. + """ + + algorithm_name: str = "text_to_text_inplace_perp" + + def __init__(self) -> None: + super().__init__(use_lora=False) + + +class TextToTextLoRA(TextToTextPERP): + """ + LoRA recoverer for text-to-text models. + + This recoverer attaches LoRA adapters to the model and finetunes them. 
+ """ + + algorithm_name: str = "text_to_text_lora" + + def __init__(self) -> None: + super().__init__(use_in_place=False) diff --git a/tests/algorithms/testers/tti_inplace_perp.py b/tests/algorithms/testers/tti_inplace_perp.py new file mode 100644 index 00000000..9fdbb27a --- /dev/null +++ b/tests/algorithms/testers/tti_inplace_perp.py @@ -0,0 +1,48 @@ +from typing import Any + +import pytest +import torch +from pruna import SmashConfig +from pruna.engine.utils import get_nn_modules + +from pruna.algorithms.perp import TextToImageInPlacePERP +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase +from .utils import restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. + """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +# Our nightlies machine does not support efficient attention mechanisms and causes OOM errors with this test. +# This test do pass on modern architectures. +@pytest.mark.high +@pytest.mark.slow +class TestTTIInPlacePerp(AlgorithmTesterBase): + """Test the TTI InPlace Perp recovery algorithm.""" + + models = ["noref_flux_tiny_random", "noref_sd_tiny_random", "noref_sana_tiny_random"] + reject_models = ["opt_tiny_random"] + allow_pickle_files = True + algorithm_class = TextToImageInPlacePERP + metrics = ["ssim"] + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) diff --git a/tests/algorithms/testers/tti_lora.py b/tests/algorithms/testers/tti_lora.py new file mode 100644 index 00000000..78b52a6a --- /dev/null +++ b/tests/algorithms/testers/tti_lora.py @@ -0,0 +1,48 @@ +from typing import Any + +import pytest +import torch +from pruna import SmashConfig +from pruna.engine.utils import get_nn_modules + +from pruna.algorithms.perp import TextToImageLoRA +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase +from .utils import restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. + """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +# Our nightlies machine does not support efficient attention mechanisms and causes OOM errors with this test. +# This test do pass on modern architectures. 
+@pytest.mark.high +@pytest.mark.slow +class TestTTILoRA(AlgorithmTesterBase): + """Test the TTI LoRA recovery algorithm.""" + + models = ["noref_flux_tiny_random", "noref_sd_tiny_random", "noref_sana_tiny_random"] + reject_models = ["opt_tiny_random"] + allow_pickle_files = True + algorithm_class = TextToImageLoRA + metrics = ["lpips"] + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) diff --git a/tests/algorithms/testers/tti_perp.py b/tests/algorithms/testers/tti_perp.py new file mode 100644 index 00000000..4cbedf67 --- /dev/null +++ b/tests/algorithms/testers/tti_perp.py @@ -0,0 +1,46 @@ +from typing import Any + +import pytest +import torch +from pruna import SmashConfig +from pruna.engine.utils import get_nn_modules + +from pruna.algorithms.perp import TextToImagePERP +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase +from .utils import restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. + """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +@pytest.mark.slow +@pytest.mark.high +class TestTTIPerp(AlgorithmTesterBase): + """Test the TTI Perp recovery algorithm.""" + + models = ["noref_flux_tiny_random", "noref_sd_tiny_random", "noref_sana_tiny_random"] + reject_models = ["opt_tiny_random"] + allow_pickle_files = True + algorithm_class = TextToImagePERP + metrics = ["cmmd"] + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) diff --git a/tests/algorithms/testers/ttt_inplace_perp.py b/tests/algorithms/testers/ttt_inplace_perp.py new file mode 100644 index 00000000..4ab3f96b --- /dev/null +++ b/tests/algorithms/testers/ttt_inplace_perp.py @@ -0,0 +1,45 @@ +from typing import Any + +import pytest +import torch +from pruna import SmashConfig +from pruna.engine.utils import get_nn_modules + +from pruna.algorithms.perp import TextToTextInPlacePERP +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase +from .utils import restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. 
+ """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +@pytest.mark.slow +class TestTTTInPlacePerp(AlgorithmTesterBase): + """Test the TTT InPlace Perp recovery algorithm.""" + + models = ["noref_opt_tiny_random"] + reject_models = ["sd_tiny_random"] + allow_pickle_files = True + algorithm_class = TextToTextInPlacePERP + metrics = ["perplexity"] + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) diff --git a/tests/algorithms/testers/ttt_lora.py b/tests/algorithms/testers/ttt_lora.py new file mode 100644 index 00000000..9e5bccf9 --- /dev/null +++ b/tests/algorithms/testers/ttt_lora.py @@ -0,0 +1,45 @@ +from typing import Any + +import pytest +import torch + +from pruna import SmashConfig +from pruna.algorithms.perp import TextToTextLoRA +from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules + +from .base_tester import AlgorithmTesterBase +from .utils import restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. + """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +@pytest.mark.slow +class TestTTTLoRA(AlgorithmTesterBase): + """Test the TTT LoRA recovery algorithm.""" + + models = ["noref_opt_tiny_random"] + reject_models = ["sd_tiny_random"] + allow_pickle_files = True + algorithm_class = TextToTextLoRA + metrics = ["perplexity"] + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) diff --git a/tests/algorithms/testers/ttt_perp.py b/tests/algorithms/testers/ttt_perp.py new file mode 100644 index 00000000..e1f857d7 --- /dev/null +++ b/tests/algorithms/testers/ttt_perp.py @@ -0,0 +1,45 @@ +from typing import Any + +import pytest +import torch +from pruna import SmashConfig +from pruna.engine.utils import get_nn_modules + +from pruna.algorithms.perp import TextToTextPERP +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase +from .utils import restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. 
+ """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +@pytest.mark.slow +class TestTTTPerp(AlgorithmTesterBase): + """Test the TTT Perp recovery algorithm.""" + + models = ["noref_opt_tiny_random"] + reject_models = ["sd_tiny_random"] + allow_pickle_files = True + algorithm_class = TextToTextPERP + metrics = ["perplexity"] + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) diff --git a/tests/algorithms/testers/utils.py b/tests/algorithms/testers/utils.py index ba86cb4d..322dcc7d 100644 --- a/tests/algorithms/testers/utils.py +++ b/tests/algorithms/testers/utils.py @@ -1,5 +1,15 @@ from typing import Any +from pruna.config.smash_config import SmashConfig + + +def restrict_recovery_time(smash_config: SmashConfig, algorithm_name: str) -> None: + """Restrict the recovery time to a few batches to test iteration multiple time but as few as possible.""" + smash_config[f"{algorithm_name}_training_batch_size"] = 1 + smash_config[f"{algorithm_name}_num_epochs"] = 1 + # restrict the number of train and validation samples in the dataset + smash_config.data.limit_datasets((2, 1, 1)) # 2 train, 1 val, 1 test + def get_model_sparsity(model: Any) -> float: """Get the sparsity of the model.""" From 5b2ac12a43235a2185087c1b9af53ee5397d0e7e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 6 Jan 2026 13:23:36 +0000 Subject: [PATCH 2/4] adding revocery algorithms with distillation --- src/pruna/algorithms/base/tags.py | 4 + src/pruna/algorithms/distillation_perp.py | 73 ++ .../finetuners/text_to_image_distiller.py | 633 ++++++++++++++++++ src/pruna/algorithms/perp.py | 19 +- .../data/diffuser_distillation_data_module.py | 351 ++++++++++ .../testers/tti_distillation_inplace_perp.py | 56 ++ .../testers/tti_distillation_lora.py | 56 ++ .../testers/tti_distillation_perp.py | 56 ++ tests/algorithms/testers/utils.py | 14 + 9 files changed, 1259 insertions(+), 3 deletions(-) create mode 100644 src/pruna/algorithms/distillation_perp.py create mode 100644 src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py create mode 100644 src/pruna/data/diffuser_distillation_data_module.py create mode 100644 tests/algorithms/testers/tti_distillation_inplace_perp.py create mode 100644 tests/algorithms/testers/tti_distillation_lora.py create mode 100644 tests/algorithms/testers/tti_distillation_perp.py diff --git a/src/pruna/algorithms/base/tags.py b/src/pruna/algorithms/base/tags.py index 2995effb..39f1d584 100644 --- a/src/pruna/algorithms/base/tags.py +++ b/src/pruna/algorithms/base/tags.py @@ -68,6 +68,10 @@ class AlgorithmTag(Enum): "recoverer", "Recovery restores the performance of a model after compression.", ) + DISTILLER = ( + "distiller", + "Distillation trains a smaller, simpler model to mimic a larger, more complex model.", + ) def __init__(self, name: str, description: str): """ diff --git a/src/pruna/algorithms/distillation_perp.py b/src/pruna/algorithms/distillation_perp.py new file mode 100644 index 00000000..6f81d8fc --- /dev/null +++ b/src/pruna/algorithms/distillation_perp.py @@ -0,0 +1,73 @@ +# Copyright 
2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Iterable + +from pruna.algorithms.base.tags import AlgorithmTag +from pruna.algorithms.perp import PERPRecoverer + + +class TextToImagePERPDistillation(PERPRecoverer): + """ + PERP distillation recoverer for text-to-image models. + + This recoverer is a general purpose PERP recoverer for text-to-image models using norm and bias finetuning + as well as LoRA layers. + + Parameters + ---------- + use_lora : bool + Whether to use LoRA adapters. + use_in_place : bool + Whether to use norm and bias finetuning which will modify the model in place. + """ + + group_tags: list[AlgorithmTag] = [AlgorithmTag.DISTILLER, AlgorithmTag.RECOVERER] # type: ignore[attr-defined] + algorithm_name = "text_to_image_distillation_perp" + tokenizer_required = False + compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache", "flux_caching"] + compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile", "x_fast"] + runs_on: list[str] = ["cuda"] + + def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None: + super().__init__(task_name="text_to_image", use_lora=use_lora, use_in_place=use_in_place, is_distillation=True) + + +class TextToImageInPlacePERPDistillation(TextToImagePERPDistillation): + """ + PERP distillation recoverer for text-to-image models without LoRA adapters. + + This is the same as ``text_to_image_distillation_perp``, but without LoRA layers which add extra computations and + thus slow down the inference of the final model. + """ + + algorithm_name = "text_to_image_distillation_inplace_perp" + + def __init__(self) -> None: + super().__init__(use_lora=False, use_in_place=True) + + +class TextToImageLoraDistillation(TextToImagePERPDistillation): + """ + LoRA distillation recoverer for text-to-image models. + + This recoverer attaches LoRA adapters to the model and uses them for distillation. + """ + + algorithm_name = "text_to_image_distillation_lora" + + def __init__(self) -> None: + super().__init__(use_lora=True, use_in_place=False) diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py new file mode 100644 index 00000000..50c54791 --- /dev/null +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py @@ -0,0 +1,633 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools +import random +from typing import Any, List, Literal + +import pytorch_lightning as pl +import torch +from diffusers.optimization import get_scheduler +from diffusers.utils import BaseOutput +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities.seed import isolate_rng + +try: + from bitsandbytes.optim import AdamW8bit # type: ignore[import-untyped] +except ImportError: + AdamW8bit = None + +import pathlib + +from ConfigSpace import ( + CategoricalHyperparameter, + Constant, + UniformFloatHyperparameter, + UniformIntegerHyperparameter, +) + +from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner +from pruna.algorithms.global_utils.recovery.finetuners.diffusers import utils +from pruna.algorithms.global_utils.recovery.finetuners.diffusers.distillation_arg_utils import ( + get_latent_replacement_fn, +) +from pruna.algorithms.global_utils.recovery.utils import ( + filter_kwargs, + get_dtype, + get_trainable_parameters, + split_defaults, +) +from pruna.config.hyperparameters import Boolean +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.data.diffuser_distillation_data_module import DiffusionDistillationDataModule +from pruna.engine.utils import get_device, get_device_type +from pruna.logging.logger import pruna_logger + + +class TextToImageDistiller(PrunaFinetuner): + """Distiller for text-to-image models.""" + + @classmethod + def get_hyperparameters(cls, **override_defaults) -> List: + """ + Configure all method-specific hyperparameters with ConfigSpace. + + Parameters + ---------- + **override_defaults : dict + The hyperparameters to override. + + Returns + ------- + list + The hyperparameters. + """ + defaults = { + "training_batch_size": 0, # 0: all steps of the diffusion process are used + "gradient_accumulation_steps": 1, + "num_epochs": 1.0, + "validate_every_n_epoch": 1.0, + "learning_rate": 1e-4, + "weight_decay": 1e-2, + "report_to": "none", + "optimizer": "AdamW8bit" if torch.cuda.is_available() else "AdamW", # AdamW8bit from BnB assumes CUDA + "lr_decay": 0.5, + "warmup_steps": 0, + } + defaults.update(override_defaults) + string_defaults, numeric_defaults = split_defaults(defaults) + + return [ + # not a true batch size, but suggests the increase of VRAM coming with higher batch_size + UniformIntegerHyperparameter( + "training_batch_size", + lower=0, + upper=4096, + default_value=numeric_defaults["training_batch_size"], + meta=dict(desc="Number of steps from each diffusion process to use for distillation."), + ), + UniformIntegerHyperparameter( + "gradient_accumulation_steps", + lower=1, + upper=1024, + default_value=numeric_defaults["gradient_accumulation_steps"], + meta=dict(desc="Number of captions processed to estimate each gradient step."), + ), + UniformIntegerHyperparameter( + "num_epochs", + lower=0, + upper=4096, + default_value=numeric_defaults["num_epochs"], + meta=dict(desc="Number of epochs for distillation."), + ), + UniformFloatHyperparameter( + "validate_every_n_epoch", + lower=0.0, + upper=4096.0, + default_value=numeric_defaults["validate_every_n_epoch"], + meta=dict( + desc="Number of epochs between each round of validation and model checkpointing. " + "If the value is between 0 and 1, validation will be performed multiple times per epoch, " + "e.g. 1/8 will result in 8 validations per epoch." 
+ ), + ), + UniformFloatHyperparameter( + "learning_rate", + lower=0.0, + upper=1.0, + default_value=numeric_defaults["learning_rate"], + meta=dict(desc="Learning rate for distillation."), + ), + Constant("weight_decay", numeric_defaults["weight_decay"]), + # report_to: for consistency with text-to-text-lora but wandb and tensorboard are not supported yet + Constant("report_to", string_defaults["report_to"]), + Boolean( + "use_cpu_offloading", + default=False, + meta=dict(desc="Whether to use CPU offloading for distillation."), + ), + CategoricalHyperparameter( + "optimizer", + choices=["AdamW8bit", "AdamW", "Adam"], + default_value=string_defaults["optimizer"], + meta=dict(desc="Which optimizer to use for distillation."), + ), + UniformFloatHyperparameter( + "lr_decay", + lower=0.0, + upper=1.0, + default_value=numeric_defaults["lr_decay"], + meta=dict(desc="Learning rate decay, applied at each epoch."), + ), + UniformIntegerHyperparameter( + "warmup_steps", + lower=0, + upper=2**14, + default_value=numeric_defaults["warmup_steps"], + meta=dict(desc="Number of warmup steps for the learning rate scheduler."), + ), + ] + + @classmethod + def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any: + """ + Train the model previously activated parameters on distillation data extracted from the original model. + + Parameters + ---------- + pipeline : Any + The pipeline containing components to finetune. + smash_config : SmashConfigPrefixWrapper + The configuration for the finetuner. + seed : int + The seed to use for reproducibility. + recoverer : str + The recoverer used, i.e. the selection of parameters to finetune. This is only used for logging purposes. + + Returns + ------- + Any + The finetuned pipeline. + """ + if not isinstance(smash_config.data, DiffusionDistillationDataModule): + raise ValueError( + f"DiffusionDistillation data module is required for distillation, but got {smash_config.data}." 
+ ) + + dtype = get_dtype(pipeline) + device = get_device(pipeline) + try: + lora_r = smash_config["lora_r"] + except KeyError: + lora_r = 0 + + # split seed into two rng: generator for the dataloader and a seed for the training part + generator = torch.Generator().manual_seed(seed) + fit_seed = int(torch.randint(0, 2**32 - 1, (1,), generator=generator).item()) + + # Dataloaders (batch size is used to decide how many diffusion steps to backprop) + train_dataloader = smash_config.train_dataloader(generator=generator) + val_dataloader = smash_config.val_dataloader() + + # Finetune the model + trainable_distiller = DistillerTL( + pipeline, + smash_config["training_batch_size"], + smash_config["gradient_accumulation_steps"], + smash_config["optimizer"], + smash_config["learning_rate"], + smash_config["lr_decay"], + smash_config["warmup_steps"], + smash_config["weight_decay"], + lora_r, + recoverer, + smash_config["use_cpu_offloading"], + pipeline_kwargs=smash_config.data.pipeline_kwargs, + ) + # make directory for logs and checkpoints + model_path = pathlib.Path(smash_config.cache_dir) / "recovery" + model_path.mkdir(parents=True) + + early_stopping = EarlyStopping(monitor="validation_loss", patience=3, mode="min", check_finite=True) + checkpoint_callback = ModelCheckpoint( + monitor="validation_loss", + save_top_k=1, + mode="min", + every_n_epochs=1, + dirpath=model_path, + filename="model-{epoch:02d}-{validation_loss:.4f}", + ) + callbacks = [early_stopping, checkpoint_callback] + + if smash_config["validate_every_n_epoch"] >= 1.0: + check_val_every_n_epoch = int(smash_config["validate_every_n_epoch"]) + val_check_interval = None + else: + check_val_every_n_epoch = 1 + val_check_interval = smash_config["validate_every_n_epoch"] + + precision: Literal["16-true", "bf16-true", "32"] + if dtype == torch.float16: + precision = "16-true" + elif dtype == torch.bfloat16: + precision = "bf16-true" + else: + precision = "32" + + accelerator = get_device_type(pipeline) + if accelerator == "accelerator": + accelerator = "auto" + trainer = pl.Trainer( + callbacks=callbacks, + max_epochs=smash_config["num_epochs"], + inference_mode=False, + precision=precision, + log_every_n_steps=smash_config["gradient_accumulation_steps"], + check_val_every_n_epoch=check_val_every_n_epoch, + val_check_interval=val_check_interval, + logger=False, + num_sanity_val_steps=0, # the train.validate already acts as a sanity check + accelerator=accelerator, + ) + + with isolate_rng(): + pl.seed_everything(fit_seed) + trainer.validate(trainable_distiller, dataloaders=val_dataloader, verbose=False) + trainer.fit(trainable_distiller, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) + + # Loading the best checkpoint is slow and currently creates conflicts with some quantization algorithms, + # e.g. diffusers_int8. Skipping calling DenoiserTL.load_from_checkpoint for now. + return pipeline.to(device) + + +class DistillerTL(pl.LightningModule): + """ + Pipeline in LightningModule format for distilling the denoiser. + + Parameters + ---------- + pipeline : Any + The pipeline to finetune. + batch_size : int + The batch size to use for distillation, used to extract a subset of steps from each diffusion process. + gradient_accumulation_steps : int + The number of prompts processed to estimate each gradient step. + optimizer_name : str + The name of the optimizer to use, options are "AdamW8bit", or optimizers from torch.optim. + learning_rate : float + The learning rate to use for finetuning. 
+    lr_decay : float
+        The learning rate decay to use for finetuning.
+    warmup_steps : int
+        The number of warmup steps to use for finetuning.
+    weight_decay : float
+        The weight decay to use for finetuning.
+    lora_r : int
+        The rank of the LoRA matrices.
+    recoverer : str
+        The recoverer used, i.e. the selection of parameters to finetune. This is only used for logging purposes.
+    use_cpu_offloading : bool, optional
+        Whether to use CPU offloading for finetuning.
+    pipeline_kwargs : dict[str, Any], optional
+        Additional keyword arguments to pass to the pipeline, such as `guidance_scale` or `num_inference_steps`.
+    """
+
+    def __init__(
+        self,
+        pipeline: Any,
+        batch_size: int,
+        gradient_accumulation_steps: int,
+        optimizer_name: str,
+        learning_rate: float,
+        lr_decay: float,
+        warmup_steps: int,
+        weight_decay: float,
+        lora_r: int,
+        recoverer: str,
+        use_cpu_offloading: bool = False,
+        pipeline_kwargs: dict[str, Any] = {},
+    ):
+        super().__init__()
+
+        self.pipeline = pipeline
+        self.latent_replacement_fn = get_latent_replacement_fn(pipeline)
+
+        self.batch_size = batch_size
+        self.gradient_accumulation_steps = gradient_accumulation_steps
+        self.optimizer_name = optimizer_name
+        self.learning_rate = learning_rate
+        self.lr_decay = lr_decay
+        self.warmup_steps = warmup_steps
+        self.weight_decay = weight_decay
+        self.lora_r = lora_r
+        self.recoverer = recoverer
+        self.use_cpu_offloading = use_cpu_offloading
+        self.pipeline_kwargs = pipeline_kwargs
+        self.save_hyperparameters(ignore=["pipeline"])
+        # register the denoiser explicitly to work with self.parameters() and other LightningModule methods
+        self.denoiser, _ = utils.get_denoiser_attr(pipeline)
+        self.num_previous_steps = 0
+        if self.denoiser is None:
+            raise ValueError("Could not find the denoiser in the pipeline.")
+        num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        self.log("num_trainable_params", num_trainable_params)
+
+        # basic logging
+        self.val_losses: List[torch.Tensor] = []
+        self.train_losses: List[torch.Tensor] = []
+        self.automatic_optimization = False
+
+    def forward(
+        self,
+        caption: str,
+        latent_inputs: torch.Tensor,
+        latent_targets: torch.Tensor,
+        seed: int,
+        active_steps: list[int] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute the denoised latents from the input latents recorded at all timesteps.
+
+        Parameters
+        ----------
+        caption : str
+            The caption to use for the denoiser.
+        latent_inputs : torch.Tensor
+            The inputs of the pipeline with shape (number_of_steps, *latent_shape) recorded in the distillation dataset.
+        latent_targets : torch.Tensor
+            The outputs of the pipeline with shape (number_of_steps, *latent_shape) recorded in the distillation dataset.
+        seed : int
+            The seed used when recording the distillation dataset.
+        active_steps : list[int] | None, optional
+            The steps to use for the distillation. If None (default), all steps are used.
+
+        Returns
+        -------
+        torch.Tensor
+            The denoised latents at each active step.
+        torch.Tensor
+            The loss for each active diffusion step.
+ """ + # some variables accessible within the denoiser's monkey patched forward + self.num_previous_steps = 0 + latent_outputs = [] + diffusion_step_losses = [] + original_forward = self.denoiser.forward # type: ignore[union-attr] + + is_training = active_steps is not None + is_first_training_step = is_training and self.num_previous_steps == 0 + + @functools.wraps(original_forward) + def distillation_forward(*args, **kwargs): + if self.use_cpu_offloading and is_first_training_step: + utils.move_secondary_components(self.pipeline, "cpu") + + # current denoising_step is self.num_previous_steps + recorded_input = latent_inputs[self.num_previous_steps] + + # select which steps to record: all during validation, only trained steps during training + if active_steps is None or self.num_previous_steps in active_steps: + with torch.set_grad_enabled(is_training): # diffusers disable gradients, re-enable them for training + args, kwargs = self.latent_replacement_fn(recorded_input, args, kwargs) + output = original_forward(*args, **kwargs) + latent_output = ( + output["sample"] if ("return_dict" in kwargs and kwargs["return_dict"]) else output[0] + ) + loss = self.loss(latent_output, latent_targets[self.num_previous_steps]) + if is_training: + accumulation_normalized_loss = loss / (len(active_steps) * self.gradient_accumulation_steps) + self.manual_backward(accumulation_normalized_loss) + diffusion_step_losses.append(loss) + latent_outputs.append(latent_output) + + # recreate the expected output format + if "return_dict" in kwargs and kwargs["return_dict"]: + recorded_output = self._get_denoiser_output_object(latent_targets[self.num_previous_steps]) + else: + recorded_output = (latent_targets[self.num_previous_steps],) + + self.num_previous_steps += 1 + return recorded_output + + self.denoiser.forward = distillation_forward # type: ignore[union-attr] + + # Run the pipeline on the recorded latents and collect the outputs + _ = self.pipeline(caption, generator=torch.Generator().manual_seed(seed), **self.pipeline_kwargs) + stacked_latent_outputs = torch.stack(latent_outputs, dim=0) + stacked_diffusion_step_losses = torch.stack(diffusion_step_losses, dim=0) + + # Restore the original forward, reversing cpu_offloading will be done only after the gradient backward + self.denoiser.forward = original_forward # type: ignore[union-attr] + + return stacked_latent_outputs, stacked_diffusion_step_losses + + def training_step(self, batch: tuple[list[str], torch.Tensor, torch.Tensor, list[int]], batch_idx: int): + """ + Compute a single-step loss from the denoiser on training data. + + Parameters + ---------- + batch : tuple[list[str], torch.Tensor, torch.Tensor, int] + The batch of (captions, latent_inputs, latent_targets, seed). + batch_idx : int + The index of the batch. + + Returns + ------- + dict[str, torch.Tensor] + The single-step training loss. 
+ """ + opt = self.optimizers() + opt.zero_grad() # type: ignore[union-attr] + + captions, latent_inputs, latent_targets, seeds = batch + assert len(captions) == 1 # only a batch size of 1 corresponding to a full diffusion process is supported + caption, latent_inputs, latent_targets, seed = captions[0], latent_inputs[0], latent_targets[0], seeds[0] + + self.pipeline.set_progress_bar_config(disable=True) + if self.use_cpu_offloading: + # avoids a bug when using gradient_accumulation_steps > 1 because the on_after_backward hasn't run yet + utils.move_secondary_components(self.pipeline, self.device) + + diffusion_steps = latent_inputs.shape[0] + trained_steps = ( + random.sample(range(diffusion_steps), min(self.batch_size, diffusion_steps)) + if self.batch_size > 0 + else list(range(diffusion_steps)) + ) + + latent_outputs, diffusion_step_losses = self.forward( + caption, latent_inputs, latent_targets, seed, active_steps=trained_steps + ) + loss = diffusion_step_losses.mean() + if (batch_idx + 1) % self.gradient_accumulation_steps == 0 or self.trainer.is_last_batch: + self.clip_gradients(opt, gradient_clip_val=1.0, gradient_clip_algorithm="norm") # type: ignore[arg-type] + opt.step() # type: ignore[union-attr] + + if self.trainer.is_last_batch: + lr_schedulers = self.lr_schedulers() + if lr_schedulers: + if isinstance(lr_schedulers, list): + for scheduler in lr_schedulers: + scheduler.step() # type: ignore[call-arg] + else: + lr_schedulers.step() + + self.log("train_loss", loss) + self.train_losses.append(loss) + + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + """ + Compute a single-step loss from the denoiser on validation data. + + Parameters + ---------- + batch : tuple[list[str], torch.Tensor] + The batch of (captions, images). + batch_idx : int + The index of the batch. + + Returns + ------- + dict[str, torch.Tensor] + The single-step validation loss. + """ + caption, latent_inputs, latent_targets, seed = batch + assert len(caption) == 1 # only a batch size of 1 corresponding to a full diffusion process is supported + caption, latent_inputs, latent_targets, seed = caption[0], latent_inputs[0], latent_targets[0], seed[0] + + self.pipeline.set_progress_bar_config(disable=True) + + latent_outputs, diffusion_step_losses = self.forward(caption, latent_inputs, latent_targets, seed) + loss = diffusion_step_losses.mean() + self.log("validation_loss", loss) + self.val_losses.append(loss) + + # no need to do CPU offloading in evaluation since gradients are not computed + return {"loss": loss} + + def on_train_epoch_end(self): + """Log the train loss.""" + loss = torch.stack(self.train_losses).mean() + pruna_logger.info(f"{self.recoverer} - epoch {self.current_epoch} - train loss: {loss:.3e}") + self.train_losses.clear() + + def on_validation_epoch_end(self): + """Log the validation loss.""" + if self.trainer.sanity_checking: + return + loss = torch.stack(self.val_losses).mean() + epoch_descr = "before distillation" if self.trainer.global_step == 0 else f"epoch {self.current_epoch}" + pruna_logger.info(f"{self.recoverer} - {epoch_descr} - validation loss: {loss:.3e}") + self.val_losses.clear() + + def on_after_backward(self): + """Move the secondary components to the device after backward is done.""" + if self.use_cpu_offloading: + # ensure the cpu_offloading is reversed, so the validation doesn't have to protect against it + utils.move_secondary_components(self.pipeline, self.device) + + def loss(self, model_pred, target): + """ + Compute the denoising loss. 
+ + Parameters + ---------- + model_pred : torch.Tensor + The predicted latents. + target : torch.Tensor + The target latents. + + Returns + ------- + torch.Tensor + The MSE loss between the predicted and target latents. + """ + return torch.nn.functional.mse_loss(model_pred, target) + + def configure_optimizers(self): + """ + Configure the optimizer. + + Returns + ------- + torch.optim.Optimizer + The optimizer. + """ + kwargs = {"eps": 1e-7} if self.trainer.precision in [16, "16", "16-true"] else {} + kwargs["lr"] = self.learning_rate + kwargs["weight_decay"] = self.weight_decay + + optimizer_cls: type[torch.optim.Optimizer] + if self.optimizer_name == "AdamW8bit": + if self.device == "cpu" or isinstance(self.device, torch.device) and self.device.type == "cpu": + pruna_logger.warning("AdamW8bit is not supported on CPU, continuing with AdamW from torch.") + optimizer_cls = torch.optim.AdamW + elif AdamW8bit is None: + pruna_logger.warning( + "Recovery with AdamW8bit requires bitsandbytes to be installed, continuing with AdamW from torch." + ) + optimizer_cls = torch.optim.AdamW + else: + optimizer_cls = AdamW8bit + else: + queried_optimizer_cls = getattr(torch.optim, self.optimizer_name) + if issubclass(queried_optimizer_cls, torch.optim.Optimizer): + optimizer_cls = queried_optimizer_cls + else: + raise ValueError(f"Invalid optimizer: {self.optimizer_name}") + + finetune_params = get_trainable_parameters(self.pipeline) + used_kwargs, unused_kwargs = filter_kwargs(optimizer_cls.__init__, kwargs) + if unused_kwargs: + pruna_logger.warning(f"Unused optimizer arguments: {list(unused_kwargs.keys())}") + optimizer = optimizer_cls(finetune_params, **used_kwargs) + + lr_scheduler: torch.optim.lr_scheduler.LRScheduler + if self.warmup_steps > 0 and self.lr_decay < 1.0: + raise ValueError("Warmup steps and lr_decay cannot both be set for now.") + elif self.warmup_steps > 0: + lr_scheduler = get_scheduler( + name="constant_with_warmup", + optimizer=optimizer, + num_warmup_steps=self.warmup_steps, + ) + return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] + elif self.lr_decay < 1.0: + lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.lr_decay) + return [optimizer], [{"scheduler": lr_scheduler, "interval": "epoch"}] + else: + return [optimizer], [] + + def _get_denoiser_output_object(self, output_tensor: torch.Tensor) -> BaseOutput: + """ + Wrap the output tensor in the BaseOutput class expected by the pipeline. + + Parameters + ---------- + output_tensor : torch.Tensor + The output tensor from the denoiser. + + Returns + ------- + BaseOutput + The wrapped output tensor. + """ + if not hasattr(self, "_denoiser_output_class"): # lazy initialization + self._denoiser_output_class = utils.get_denoiser_output_class(self.denoiser) + return self._denoiser_output_class(sample=output_tensor) diff --git a/src/pruna/algorithms/perp.py b/src/pruna/algorithms/perp.py index eb6a689a..617bd8c0 100644 --- a/src/pruna/algorithms/perp.py +++ b/src/pruna/algorithms/perp.py @@ -1,13 +1,26 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations from typing import Iterable from pruna.algorithms.base.tags import AlgorithmTag - from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer -class TextToImagePERP(PERPRecoverer): +class TextToImagePERP(PERPRecoverer): """ PERP recoverer for text-to-image models. @@ -97,7 +110,7 @@ def __init__(self) -> None: super().__init__(use_lora=False) -class TextToTextLoRA(TextToTextPERP): +class TextToTextLoRA(TextToTextPERP): """ LoRA recoverer for text-to-text models. diff --git a/src/pruna/data/diffuser_distillation_data_module.py b/src/pruna/data/diffuser_distillation_data_module.py new file mode 100644 index 00000000..7a5dcedc --- /dev/null +++ b/src/pruna/data/diffuser_distillation_data_module.py @@ -0,0 +1,351 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools +import json +from pathlib import Path +from typing import Any, List, Optional, Tuple + +import torch +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from pruna.algorithms.global_utils.recovery.finetuners.diffusers.distillation_arg_utils import ( + get_latent_extractor_fn, +) +from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr +from pruna.data.pruna_datamodule import PrunaDataModule + + +class DiffusionDistillationDataModule(PrunaDataModule): + """ + A distillation datamodule containing captions, random seeds and input-output pairs during the diffusion process. + + Parameters + ---------- + pipeline : Any + The diffusion pipeline used to generate the distillation data. The diffuser is expected to have a denoiser + (either `unet` or `transformer` attribute) whose output is either a tuple with latent output in index 0, or a + dict-like object with the latent output in the `sample` key. + caption_datamodule : pl.LightningDataModule + The caption datamodule. Each batch is expected to be an iterable of captions or a tuple whose first + element is an iterable of captions. + save_path : str | Path + The path to save the distillation data. + seed : int | None + The random seed used to generate unique ids and random seeds for the distillation data. + If None, a random seed will be generated. This seed will be saved in the `parameters.json` file for future runs. + pipeline_kwargs : dict[str, Any] + Additional keyword arguments to pass to the pipeline, such as `guidance_scale` or `num_inference_steps`. 
+ """ + + def __init__( + self, + pipeline: Any, + caption_datamodule: PrunaDataModule, + save_path: str | Path = "distillation_data", + seed: int | None = None, + pipeline_kwargs: dict[str, Any] = {}, + ): + self.pipeline = pipeline + self.pipeline_kwargs = pipeline_kwargs + self.collection_helper: InternalStateCollectionHelper | None = InternalStateCollectionHelper(self.pipeline) + + self.caption_datamodule = caption_datamodule + self.save_path = Path(save_path) + + if seed is None: + param_path = self.save_path / "parameters.json" + if param_path.exists(): # load previously saved seed + with open(param_path, "r") as f: + seed = json.load(f)["seed"] + if not isinstance(seed, int): + raise ValueError(f"Seed must be an integer, but got {seed} from the parameters.json file.") + else: + self.seed = seed + else: + self.seed = _get_random_seed() + else: + self.seed = seed + generator = torch.Generator().manual_seed(self.seed) + self.dataloader_generators = { + subset: torch.Generator().manual_seed(_get_random_seed(generator)) for subset in ["train", "val", "test"] + } + self.seed_making_generators = { + subset: torch.Generator().manual_seed(_get_random_seed(generator)) for subset in ["train", "val", "test"] + } + + train_filenames, val_filenames, test_filenames = self.prepare_distillation_dataset() + + super().__init__( + train_ds=DiffusionDistillationDataset(self.save_path / "train", train_filenames), + val_ds=DiffusionDistillationDataset(self.save_path / "val", val_filenames), + test_ds=DiffusionDistillationDataset(self.save_path / "test", test_filenames), + collate_fn=DiffusionDistillationDataset.collate_fn, + dataloader_args={}, + ) + + def prepare_distillation_dataset(self) -> Tuple[List[str], List[str], List[str]]: + """ + Prepare the distillation data. + + Returns + ------- + Tuple[List[str], List[str], List[str]] + The filenames of the train, val and test datasets. 
+ """ + if self.pipeline is None or self.collection_helper is None: + # this can happen because those attributes are set to None at the end of this method + raise ValueError("prepare_distillation_dataset() can only be called once.") + + self.collection_helper.enable() + + # save progress bar state to restore it later + if hasattr(self.pipeline, "_progress_bar_config"): + progress_bar_state = dict(self.pipeline._progress_bar_config) + else: + progress_bar_state = {} + self.pipeline.set_progress_bar_config(disable=True) + + self.save_path.mkdir(exist_ok=True, parents=True) + parameters = { + "pipeline_kwargs": self.pipeline_kwargs, + "seed": self.seed, + } + with Path(self.save_path / "parameters.json").open("w") as f: + json.dump(parameters, f) + + train_filenames = self._prepare_one_dataset( + self.caption_datamodule.train_dataloader(generator=self.dataloader_generators["train"]), + "train", + ) + val_filenames = self._prepare_one_dataset( + self.caption_datamodule.val_dataloader(generator=self.dataloader_generators["val"]), + "val", + ) + test_filenames = self._prepare_one_dataset( + self.caption_datamodule.test_dataloader(generator=self.dataloader_generators["test"]), + "test", + ) + self.pipeline.set_progress_bar_config(**progress_bar_state) + self.collection_helper.disable() + + # pipeline should not be needed by this module anymore, so we drop the reference so it can be deleted elsewhere + self.collection_helper = None + self.pipeline = None + return train_filenames, val_filenames, test_filenames + + def _prepare_one_dataset( + self, + dataloader: Optional[DataLoader], + subdir_name: str, + ) -> List[str]: + """ + Setup a single dataset and save it to the path. + + Parameters + ---------- + dataloader : Optional[DataLoader] + The dataloader to use to prepare the dataset. + subdir_name : str + The name of the subdirectory to save the dataset to, in ["train", "val", "test"]. + + Returns + ------- + List[str] + The filenames of the dataset. + """ + Path(self.save_path / subdir_name).mkdir(exist_ok=True, parents=True) + desc = f"Prepare {subdir_name} distillation dataset" + filenames: List[str] = [] + + for batch in tqdm(dataloader, desc=desc): + captions = batch if isinstance(batch[0], str) else batch[0] + for caption in captions: + filename = f"{len(filenames)}.pt" + self._prepare_one_sample(filename, caption, subdir_name) + filenames.append(filename) + return filenames + + @torch.no_grad() + def _prepare_one_sample(self, filename: str, caption: str, subdir_name: str) -> None: + """ + Prepare a single sample and save it to the path. + + Parameters + ---------- + filename : str + The filename of the sample. + caption : str + The caption of the sample. + subdir_name : str + The name of the subdirectory to save the sample to, in ["train", "val", "test"]. + """ + assert ( + self.pipeline is not None and self.collection_helper is not None + ), "prepare_one_sample() can only be called once." 
+
+        seed = _get_random_seed(self.seed_making_generators[subdir_name])
+        filepath = self.save_path / subdir_name / filename
+        if filepath.exists():
+            return  # file was generated in a previous run
+
+        self.collection_helper.new_sample()
+        self.pipeline(caption, generator=torch.Generator().manual_seed(seed), **self.pipeline_kwargs)
+        inputs, outputs = self.collection_helper.get_sample()
+
+        sample = {
+            "caption": caption,
+            "inputs": inputs,
+            "outputs": outputs,
+            "seed": seed,
+        }
+        torch.save(sample, filepath)
+
+
+class DiffusionDistillationDataset(Dataset):
+    """
+    Dataset for distilling a diffusion pipeline, containing captions, latent inputs, latent outputs and seeds.
+
+    Parameters
+    ----------
+    path : Path
+        The path to the distillation data.
+    filenames : Optional[List[str]]
+        The filenames to load from the path. If None, all files in the path will be loaded.
+    """
+
+    def __init__(self, path: Path, filenames: Optional[List[str]] = None):
+        self.path = path
+        if filenames is None:
+            self.filenames = [p.name for p in self.path.iterdir()]
+        else:
+            self.filenames = filenames
+
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset."""
+        return len(self.filenames)
+
+    def __getitem__(self, idx: int) -> Tuple[str, torch.Tensor, torch.Tensor, int]:
+        """Get an item from the dataset."""
+        filepath = self.path / self.filenames[idx]
+        # Loading each sample from disk is the most generic approach, but repeated disk access may become a bottleneck
+        # Loading the whole dataset into memory is often possible given the typically small size of distillation datasets
+        # This can be explored if disk access is identified as causing a latency bottleneck
+        sample = torch.load(filepath)
+        return sample["caption"], sample["inputs"], sample["outputs"], sample["seed"]
+
+    @staticmethod
+    def collate_fn(
+        samples: List[Tuple[str, torch.Tensor, torch.Tensor, int]],
+    ) -> Tuple[List[str], torch.Tensor, torch.Tensor, List[int]]:
+        """
+        Collate the samples into a batch.
+
+        Parameters
+        ----------
+        samples : List[Tuple[str, torch.Tensor, torch.Tensor, int]]
+            The samples to collate, composed of a caption, a latent input, a latent output and a seed.
+
+        Returns
+        -------
+        Tuple[List[str], torch.Tensor, torch.Tensor, List[int]]
+            The collated samples.
+        """
+        captions = [sample[0] for sample in samples]
+        inputs = torch.stack([sample[1] for sample in samples], dim=0)
+        outputs = torch.stack([sample[2] for sample in samples], dim=0)
+        seeds = [sample[3] for sample in samples]
+        return captions, inputs, outputs, seeds
+
+
+class InternalStateCollectionHelper:
+    """
+    Helper class to collect internal states from the pipeline.
+
+    When enabled, the denoiser's forward will be monkey patched to save its inputs and outputs.
+    They can be collected by calling `new_sample`, running the pipeline and then calling `get_sample`.
+    The helper can be disabled to restore the original forward.
+
+    Parameters
+    ----------
+    pipeline : Any
+        The pipeline whose denoiser inputs and outputs are collected for distillation.
+ """ + + def __init__(self, pipeline: Any) -> None: + denoiser, _ = get_denoiser_attr(pipeline) + if denoiser is None: + raise ValueError("Could not find a denoiser in the pipeline.") + self.denoiser: torch.nn.Module = denoiser + self.latent_extractor_fn = get_latent_extractor_fn(pipeline) + + self.original_forward = self.denoiser.forward + self.inputs: List[torch.Tensor] = [] + self.outputs: List[torch.Tensor] = [] + + def new_sample(self) -> None: + """Reset the state of the forward hook before calling the pipeline, must be called before each new sample.""" + self.inputs = [] + self.outputs = [] + + def get_sample(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Return the saved internal states after the pipeline has been called. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + The saved inputs and outputs. + """ + inputs = torch.stack(self.inputs, dim=0) + outputs = torch.stack(self.outputs, dim=0) + return inputs, outputs + + def enable(self) -> None: + """Enable the collection of internal states when running the pipeline.""" + + @functools.wraps(self.denoiser.forward) + def forward(*args, **kwargs): + latent = self.latent_extractor_fn(args, kwargs) + self.inputs.append(latent) + results = self.original_forward(*args, **kwargs) + output = results["sample"] if ("return_dict" in kwargs and kwargs["return_dict"]) else results[0] + self.outputs.append(output) + return results + + self.denoiser.forward = forward + + def disable(self) -> None: + """Disable the collection of internal states.""" + self.denoiser.forward = self.original_forward + + +def _get_random_seed(generator: torch.Generator | None = None) -> int: + """ + Randomly generate a random seed. + + Parameters + ---------- + generator : torch.Generator | None + The generator to use to generate the seed. If None, the current rng state will be used. + + Returns + ------- + int + The generated seed. + """ + return int(torch.randint(0, 2**32 - 1, (1,), generator=generator).item()) diff --git a/tests/algorithms/testers/tti_distillation_inplace_perp.py b/tests/algorithms/testers/tti_distillation_inplace_perp.py new file mode 100644 index 00000000..ef73f249 --- /dev/null +++ b/tests/algorithms/testers/tti_distillation_inplace_perp.py @@ -0,0 +1,56 @@ +from typing import Any + +import pytest +import torch +from pruna import SmashConfig +from pruna.engine.utils import get_nn_modules + +from pruna.algorithms.distillation_perp import TextToImageInPlacePERPDistillation +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase +from .utils import replace_datamodule_with_distillation_datamodule, restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. 
+ """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +@pytest.mark.slow +class TestTTIDistillationInPlacePerp(AlgorithmTesterBase): + """Test the TTI Distillation InPlace Perp recovery algorithm.""" + + models = ["flux_tiny_random", "sd_tiny_random", "sana_tiny_random"] + reject_models = ["opt_125m"] + metrics = ["cmmd"] + allow_pickle_files = True + algorithm_class = TextToImageInPlacePERPDistillation + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) + + def execute_smash(self, model: Any, smash_config: SmashConfig) -> Any: + """Execute the smash.""" + if any("distillation" in algorithm for algorithm in smash_config.get_active_algorithms()): + self.replaced_datamodule = smash_config.data + replace_datamodule_with_distillation_datamodule(smash_config, model) + smashed_model = super().execute_smash(model, smash_config) + if any("distillation" in algorithm for algorithm in smash_config.get_active_algorithms()): + smash_config.add_data(self.replaced_datamodule) + self.replaced_datamodule = None + return smashed_model diff --git a/tests/algorithms/testers/tti_distillation_lora.py b/tests/algorithms/testers/tti_distillation_lora.py new file mode 100644 index 00000000..ef96d593 --- /dev/null +++ b/tests/algorithms/testers/tti_distillation_lora.py @@ -0,0 +1,56 @@ +from typing import Any + +import pytest +import torch +from pruna import SmashConfig +from pruna.engine.utils import get_nn_modules + +from pruna.algorithms.distillation_perp import TextToImageLoraDistillation +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase +from .utils import replace_datamodule_with_distillation_datamodule, restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. 
+ """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +@pytest.mark.slow +class TestTTIDistillationLoRA(AlgorithmTesterBase): + """Test the TTI Distillation LoRA recovery algorithm.""" + + models = ["flux_tiny_random", "sd_tiny_random", "sana_tiny_random"] + reject_models = ["opt_tiny_random"] + metrics = ["cmmd"] + allow_pickle_files = True + algorithm_class = TextToImageLoraDistillation + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) + + def execute_smash(self, model: Any, smash_config: SmashConfig) -> Any: + """Execute the smash.""" + if any("distillation" in algorithm for algorithm in smash_config.get_active_algorithms()): + self.replaced_datamodule = smash_config.data + replace_datamodule_with_distillation_datamodule(smash_config, model) + smashed_model = super().execute_smash(model, smash_config) + if any("distillation" in algorithm for algorithm in smash_config.get_active_algorithms()): + smash_config.add_data(self.replaced_datamodule) + self.replaced_datamodule = None + return smashed_model diff --git a/tests/algorithms/testers/tti_distillation_perp.py b/tests/algorithms/testers/tti_distillation_perp.py new file mode 100644 index 00000000..839591ae --- /dev/null +++ b/tests/algorithms/testers/tti_distillation_perp.py @@ -0,0 +1,56 @@ +from typing import Any + +import pytest +import torch +from pruna import SmashConfig +from pruna.engine.utils import get_nn_modules + +from pruna.algorithms.distillation_perp import TextToImagePERPDistillation +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase +from .utils import replace_datamodule_with_distillation_datamodule, restrict_recovery_time + + +def assert_no_nan_values(module: Any) -> None: + """Check for NaN values in the module or its components. + + Parameters + ---------- + module : Any + The module to check. 
+ """ + for nn_module in get_nn_modules(module).values(): + for name, param in nn_module.named_parameters(): + assert not torch.isnan(param).any(), f"NaN values found in {name}" + + +@pytest.mark.slow +class TestTTIDistillationPerp(AlgorithmTesterBase): + """Test the TTI Distillation Perp recovery algorithm.""" + + models = ["flux_tiny_random", "sd_tiny_random", "sana_tiny_random"] + reject_models = ["opt_tiny_random"] + metrics = ["cmmd"] + allow_pickle_files = True + algorithm_class = TextToImagePERPDistillation + + def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: + """Prepare the smash config for the test.""" + super().prepare_smash_config(smash_config, device) + restrict_recovery_time(smash_config, self.algorithm_class.algorithm_name) + + def post_smash_hook(self, model: PrunaModel) -> None: + """Fast hook to verify algorithm application after smashing.""" + assert_no_nan_values(model) + + def execute_smash(self, model: Any, smash_config: SmashConfig) -> Any: + """Execute the smash.""" + if any("distillation" in algorithm for algorithm in smash_config.get_active_algorithms()): + self.replaced_datamodule = smash_config.data + replace_datamodule_with_distillation_datamodule(smash_config, model) + smashed_model = super().execute_smash(model, smash_config) + if any("distillation" in algorithm for algorithm in smash_config.get_active_algorithms()): + smash_config.add_data(self.replaced_datamodule) + self.replaced_datamodule = None + return smashed_model diff --git a/tests/algorithms/testers/utils.py b/tests/algorithms/testers/utils.py index 322dcc7d..13b4cc56 100644 --- a/tests/algorithms/testers/utils.py +++ b/tests/algorithms/testers/utils.py @@ -1,6 +1,8 @@ +import os from typing import Any from pruna.config.smash_config import SmashConfig +from pruna.data.diffuser_distillation_data_module import DiffusionDistillationDataModule def restrict_recovery_time(smash_config: SmashConfig, algorithm_name: str) -> None: @@ -11,6 +13,18 @@ def restrict_recovery_time(smash_config: SmashConfig, algorithm_name: str) -> No smash_config.data.limit_datasets((2, 1, 1)) # 2 train, 1 val, 1 test +def replace_datamodule_with_distillation_datamodule(smash_config: SmashConfig, model: Any) -> None: + """Create a distillation datamodule from the model and replace the datamodule in the smash config.""" + cache_dir = os.path.join(smash_config.cache_dir, f"{model.__class__.__name__.lower()}_distillation") + distillation_data = DiffusionDistillationDataModule( + pipeline=model, + caption_datamodule=smash_config.data, + save_path=cache_dir, + seed=0, + ) + smash_config.add_data(distillation_data) + + def get_model_sparsity(model: Any) -> float: """Get the sparsity of the model.""" total_params = 0 From 0cd3df944166eaafa912aa085457712f42e7ae77 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 9 Jan 2026 16:35:22 +0000 Subject: [PATCH 3/4] fixed linting errors --- tests/algorithms/testers/tti_distillation_inplace_perp.py | 4 ++-- tests/algorithms/testers/tti_distillation_lora.py | 4 ++-- tests/algorithms/testers/tti_distillation_perp.py | 4 ++-- tests/algorithms/testers/tti_inplace_perp.py | 4 ++-- tests/algorithms/testers/tti_lora.py | 4 ++-- tests/algorithms/testers/tti_perp.py | 4 ++-- tests/algorithms/testers/ttt_inplace_perp.py | 4 ++-- tests/algorithms/testers/ttt_perp.py | 4 ++-- tests/algorithms/testers/utils.py | 4 ++-- 9 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/algorithms/testers/tti_distillation_inplace_perp.py 
b/tests/algorithms/testers/tti_distillation_inplace_perp.py index ef73f249..b8f4b3e9 100644 --- a/tests/algorithms/testers/tti_distillation_inplace_perp.py +++ b/tests/algorithms/testers/tti_distillation_inplace_perp.py @@ -2,11 +2,11 @@ import pytest import torch -from pruna import SmashConfig -from pruna.engine.utils import get_nn_modules +from pruna import SmashConfig from pruna.algorithms.distillation_perp import TextToImageInPlacePERPDistillation from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules from .base_tester import AlgorithmTesterBase from .utils import replace_datamodule_with_distillation_datamodule, restrict_recovery_time diff --git a/tests/algorithms/testers/tti_distillation_lora.py b/tests/algorithms/testers/tti_distillation_lora.py index ef96d593..8a2a2816 100644 --- a/tests/algorithms/testers/tti_distillation_lora.py +++ b/tests/algorithms/testers/tti_distillation_lora.py @@ -2,11 +2,11 @@ import pytest import torch -from pruna import SmashConfig -from pruna.engine.utils import get_nn_modules +from pruna import SmashConfig from pruna.algorithms.distillation_perp import TextToImageLoraDistillation from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules from .base_tester import AlgorithmTesterBase from .utils import replace_datamodule_with_distillation_datamodule, restrict_recovery_time diff --git a/tests/algorithms/testers/tti_distillation_perp.py b/tests/algorithms/testers/tti_distillation_perp.py index 839591ae..27536881 100644 --- a/tests/algorithms/testers/tti_distillation_perp.py +++ b/tests/algorithms/testers/tti_distillation_perp.py @@ -2,11 +2,11 @@ import pytest import torch -from pruna import SmashConfig -from pruna.engine.utils import get_nn_modules +from pruna import SmashConfig from pruna.algorithms.distillation_perp import TextToImagePERPDistillation from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules from .base_tester import AlgorithmTesterBase from .utils import replace_datamodule_with_distillation_datamodule, restrict_recovery_time diff --git a/tests/algorithms/testers/tti_inplace_perp.py b/tests/algorithms/testers/tti_inplace_perp.py index 9fdbb27a..d122f633 100644 --- a/tests/algorithms/testers/tti_inplace_perp.py +++ b/tests/algorithms/testers/tti_inplace_perp.py @@ -2,11 +2,11 @@ import pytest import torch -from pruna import SmashConfig -from pruna.engine.utils import get_nn_modules +from pruna import SmashConfig from pruna.algorithms.perp import TextToImageInPlacePERP from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules from .base_tester import AlgorithmTesterBase from .utils import restrict_recovery_time diff --git a/tests/algorithms/testers/tti_lora.py b/tests/algorithms/testers/tti_lora.py index 78b52a6a..a15f35fb 100644 --- a/tests/algorithms/testers/tti_lora.py +++ b/tests/algorithms/testers/tti_lora.py @@ -2,11 +2,11 @@ import pytest import torch -from pruna import SmashConfig -from pruna.engine.utils import get_nn_modules +from pruna import SmashConfig from pruna.algorithms.perp import TextToImageLoRA from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules from .base_tester import AlgorithmTesterBase from .utils import restrict_recovery_time diff --git a/tests/algorithms/testers/tti_perp.py b/tests/algorithms/testers/tti_perp.py index 4cbedf67..e3c5ad9e 100644 --- a/tests/algorithms/testers/tti_perp.py +++ 
b/tests/algorithms/testers/tti_perp.py @@ -2,11 +2,11 @@ import pytest import torch -from pruna import SmashConfig -from pruna.engine.utils import get_nn_modules +from pruna import SmashConfig from pruna.algorithms.perp import TextToImagePERP from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules from .base_tester import AlgorithmTesterBase from .utils import restrict_recovery_time diff --git a/tests/algorithms/testers/ttt_inplace_perp.py b/tests/algorithms/testers/ttt_inplace_perp.py index 4ab3f96b..81ecc43b 100644 --- a/tests/algorithms/testers/ttt_inplace_perp.py +++ b/tests/algorithms/testers/ttt_inplace_perp.py @@ -2,11 +2,11 @@ import pytest import torch -from pruna import SmashConfig -from pruna.engine.utils import get_nn_modules +from pruna import SmashConfig from pruna.algorithms.perp import TextToTextInPlacePERP from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules from .base_tester import AlgorithmTesterBase from .utils import restrict_recovery_time diff --git a/tests/algorithms/testers/ttt_perp.py b/tests/algorithms/testers/ttt_perp.py index e1f857d7..8b654e6e 100644 --- a/tests/algorithms/testers/ttt_perp.py +++ b/tests/algorithms/testers/ttt_perp.py @@ -2,11 +2,11 @@ import pytest import torch -from pruna import SmashConfig -from pruna.engine.utils import get_nn_modules +from pruna import SmashConfig from pruna.algorithms.perp import TextToTextPERP from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import get_nn_modules from .base_tester import AlgorithmTesterBase from .utils import restrict_recovery_time diff --git a/tests/algorithms/testers/utils.py b/tests/algorithms/testers/utils.py index 13b4cc56..1bf6ad50 100644 --- a/tests/algorithms/testers/utils.py +++ b/tests/algorithms/testers/utils.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from typing import Any from pruna.config.smash_config import SmashConfig @@ -15,7 +15,7 @@ def restrict_recovery_time(smash_config: SmashConfig, algorithm_name: str) -> No def replace_datamodule_with_distillation_datamodule(smash_config: SmashConfig, model: Any) -> None: """Create a distillation datamodule from the model and replace the datamodule in the smash config.""" - cache_dir = os.path.join(smash_config.cache_dir, f"{model.__class__.__name__.lower()}_distillation") + cache_dir = Path(smash_config.cache_dir) / f"{model.__class__.__name__.lower()}_distillation" distillation_data = DiffusionDistillationDataModule( pipeline=model, caption_datamodule=smash_config.data, From 1697f2dac3dd0edba850182990173fbbc96f57d8 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 14 Jan 2026 10:40:29 +0000 Subject: [PATCH 4/4] fixing typo --- src/pruna/algorithms/distillation_perp.py | 6 +++--- src/pruna/algorithms/perp.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pruna/algorithms/distillation_perp.py b/src/pruna/algorithms/distillation_perp.py index 6f81d8fc..e6a74262 100644 --- a/src/pruna/algorithms/distillation_perp.py +++ b/src/pruna/algorithms/distillation_perp.py @@ -17,7 +17,7 @@ from typing import Iterable from pruna.algorithms.base.tags import AlgorithmTag -from pruna.algorithms.perp import PERPRecoverer +from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer class TextToImagePERPDistillation(PERPRecoverer): @@ -38,8 +38,8 @@ class TextToImagePERPDistillation(PERPRecoverer): group_tags: list[AlgorithmTag] = [AlgorithmTag.DISTILLER, AlgorithmTag.RECOVERER] # type: 
ignore[attr-defined] algorithm_name = "text_to_image_distillation_perp" tokenizer_required = False - compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache", "flux_caching"] - compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile", "x_fast"] + compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache"] + compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"] runs_on: list[str] = ["cuda"] def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None: diff --git a/src/pruna/algorithms/perp.py b/src/pruna/algorithms/perp.py index 617bd8c0..d55cb5c5 100644 --- a/src/pruna/algorithms/perp.py +++ b/src/pruna/algorithms/perp.py @@ -37,8 +37,8 @@ class TextToImagePERP(PERPRecoverer): algorithm_name: str = "text_to_image_perp" tokenizer_required: bool = False - compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache", "flux_caching"] - compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile", "x_fast"] + compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache"] + compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"] runs_on: list[str] = ["cuda"] def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None: @@ -90,7 +90,7 @@ class TextToTextPERP(PERPRecoverer): algorithm_name: str = "text_to_text_perp" tokenizer_required: bool = True compatible_before: Iterable[str | AlgorithmTag] = ["half", "quanto", "torch_dynamic"] - compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile", "x_fast"] + compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"] def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None: super().__init__(task_name="text_to_text", use_lora=use_lora, use_in_place=use_in_place, is_distillation=False)
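
The testers above exercise the distillation recoverers end to end: they record denoiser inputs and outputs with DiffusionDistillationDataModule, shorten the training schedule, and then smash. The sketch below shows how the same pieces could be combined outside the test harness. It is a minimal, hypothetical example: the `smash` entry point, the "LAION256" dataset name, the "quantizer" and "recoverer" config keys, and the model id are assumptions about the surrounding Pruna API and are not defined by this patch series.

# Hypothetical usage sketch for the recovery algorithms added in this series.
# Assumptions (not confirmed by these patches): `smash` is the public entry point,
# the recoverer is selected via the "recoverer" key, and "LAION256" is an available
# caption dataset name.
import torch
from diffusers import StableDiffusionPipeline

from pruna import SmashConfig, smash
from pruna.data.diffuser_distillation_data_module import DiffusionDistillationDataModule

pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")

smash_config = SmashConfig()
smash_config.add_data("LAION256")  # caption source; dataset name is an assumption

# Record denoiser inputs/outputs from the uncompressed pipeline so the distiller
# can later replay each diffusion step against the recorded latents.
distillation_data = DiffusionDistillationDataModule(
    pipeline=pipeline,
    caption_datamodule=smash_config.data,
    save_path="distillation_data",
    seed=0,
    pipeline_kwargs={"num_inference_steps": 20},
)
smash_config.add_data(distillation_data)

# Compress with a quantizer listed as compatible_before, then recover with the
# PERP distillation recoverer (the "recoverer" key name is an assumption).
smash_config["quantizer"] = "quanto"
smash_config["recoverer"] = "text_to_image_distillation_perp"
smash_config["text_to_image_distillation_perp_num_epochs"] = 1
smash_config["text_to_image_distillation_perp_training_batch_size"] = 4

smashed_model = smash(model=pipeline, smash_config=smash_config)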