Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 72 additions & 28 deletions auto_round/compressors/diffusion/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class DiffusionCompressor(BaseCompressor):
The more it is, the more closely it follows the prompt (default is 7.5).
num_inference_steps (int): The reference number of denoising steps (default is 50).
generator_seed (int): A seed that controls the initial noise from which an image is generated (default is None).
pipeline_fn (callable, optional): Custom callable to run the pipeline during calibration.
Signature: ``fn(pipe, prompts, *, guidance_scale, num_inference_steps, generator, **kwargs)``.
Use this to support models whose inference API differs from the standard convention
(e.g. NextStep). If ``None``, the standard ``pipe(prompts, ...)`` call is used unless
the loaded pipeline already exposes an ``_autoround_pipeline_fn`` attribute.
scheme: (str| dict | QuantizationScheme ): A preset scheme that defines the quantization configurations.
layer_config (dict): Configuration for weight quantization (default is None).
dataset: The path or name of the calib dataset.
Expand Down Expand Up @@ -102,9 +107,15 @@ def __init__(
device_map: Union[str, torch.device, int, dict] = 0,
enable_torch_compile: bool = False,
seed: int = 42,
pipeline_fn: callable = None,
**kwargs,
):
logger.warning("Diffusion model quantization is experimental and is only validated on Flux models.")
if dataset == "NeelNanda/pile-10k":
dataset = "coco2014"
logger.warning(
"Dataset 'NeelNanda/pile-10k' is not suitable for diffusion model quantization, use coco2014 dataset instead."
)
model_dtype = kwargs.pop("model_dtype", None)

self.guidance_scale = guidance_scale
Expand All @@ -120,6 +131,8 @@ def __init__(

self.model = model
self.pipe = pipe
# Use explicit pipeline_fn; fall back to whatever diffusion_load_model attached to the pipe
self.pipeline_fn = pipeline_fn or getattr(pipe, "_autoround_pipeline_fn", None)

all_blocks = get_block_names(model)
self.quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
Expand Down Expand Up @@ -278,6 +291,64 @@ def _get_current_num_elm(
current_input_ids = [input_ids["hidden_states"][i] for i in indices]
return sum(id.numel() for id in current_input_ids)

def _run_pipeline(self, prompts: list) -> None:
"""Execute one full diffusion pipeline forward pass for calibration input capture.

This drives all transformer blocks so that their intermediate inputs are recorded
by the hooks installed during calibration.

**Extending for custom models** – choose whichever approach is simpler:

* Pass a ``pipeline_fn`` to the constructor (no subclassing required). The
callable receives ``(pipe, prompts, *, guidance_scale, num_inference_steps,
generator, **kwargs)`` and must trigger a full forward pass.
* Subclass :class:`DiffusionCompressor` and override this method directly for
full control over the inference logic.

Example – NextStep model::

def nextstep_fn(pipe, prompts, guidance_scale=7.5,
num_inference_steps=28, generator=None,
hw=(1024, 1024), **kwargs):
for prompt in (prompts if isinstance(prompts, list) else [prompts]):
pipe.generate_image(
prompt,
cfg=guidance_scale,
num_sampling_steps=num_inference_steps,
hw=hw,
**kwargs,
)

compressor = DiffusionCompressor(
model="path/to/nextstep",
pipeline_fn=nextstep_fn,
pipeline_fn_kwargs={"hw": (512, 512)},
)

Args:
prompts (list[str]): Text prompts for the current calibration batch.
"""
generator = (
None
if self.generator_seed is None
else torch.Generator(device=self.pipe.device).manual_seed(self.generator_seed)
)
if self.pipeline_fn is not None:
self.pipeline_fn(
self.pipe,
prompts,
guidance_scale=self.guidance_scale,
num_inference_steps=self.num_inference_steps,
generator=generator,
)
else:
self.pipe(
prompts,
guidance_scale=self.guidance_scale,
num_inference_steps=self.num_inference_steps,
generator=generator,
)

def calib(self, nsamples, bs):
"""Perform calibration for quantization.

Expand Down Expand Up @@ -308,40 +379,13 @@ def calib(self, nsamples, bs):
total_cnt = 0

total = nsamples if not hasattr(self.dataloader, "len") else min(nsamples, len(self.dataloader))
if self.pipe.dtype != self.model.dtype:
self.pipe.to(self.model.dtype)

if (
hasattr(self.model, "hf_device_map")
and len(self.model.hf_device_map) > 0
and self.pipe.device != self.model.device
and torch.device(self.model.device).type in ["cuda", "xpu"]
):
logger.error(
"Diffusion model is activated sequential model offloading, it will crash during moving to GPU/XPU. "
"Please use model path for quantization or "
"move the pipeline object to GPU/XPU before passing them into API."
)
exit(-1)

if self.pipe.device != self.model.device:
self.pipe.to(self.model.device)
self.pipe.to(self.model.dtype)
with tqdm(range(1, total + 1), desc="cache block inputs") as pbar:
for ids, prompts in self.dataloader:
if isinstance(prompts, tuple):
prompts = list(prompts)
try:
self.pipe(
prompt=prompts,
guidance_scale=self.guidance_scale,
num_inference_steps=self.num_inference_steps,
generator=(
None
if self.generator_seed is None
else torch.Generator(device=self.pipe.device).manual_seed(self.generator_seed)
),
)
self._run_pipeline(prompts)
except NotImplementedError:
pass
except Exception as error:
Expand Down
74 changes: 74 additions & 0 deletions auto_round/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,46 @@ def diffusion_load_model(
if device_str is not None and "hpu" in device_str:
torch_dtype = torch.bfloat16

try:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
except:
config = None

model_type = getattr(config, "model_type", "")
# A special case for NextStep
if model_type == "nextstep":
from models.gen_pipeline import NextStepPipeline # pylint: disable=E0401
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, local_files_only=True, trust_remote_code=True
)
model = AutoModel.from_pretrained(pretrained_model_name_or_path, local_files_only=True, trust_remote_code=True)
# The model is loaded onto the device because more than one block requires input data.
pipe = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device_str, dtype=torch.bfloat16)

def _nextstep_pipeline_fn(pipe, prompts, guidance_scale=7.5, num_inference_steps=28, generator=None, **kwargs):
"""Default pipeline_fn for NextStep models.

Maps standard :class:`DiffusionCompressor` parameters to NextStep's
``generate_image`` API. Pass a custom ``pipeline_fn`` to
:class:`DiffusionCompressor` to override defaults or supply
model-specific kwargs (e.g. ``hw``, ``positive_prompt``,
``cfg_schedule``, ``timesteps_shift``).
"""
for prompt in (prompts if isinstance(prompts, list) else [prompts]):
pipe.generate_image(
prompt,
cfg=guidance_scale,
num_sampling_steps=num_inference_steps,
**kwargs,
)

pipe._autoround_pipeline_fn = _nextstep_pipeline_fn
return pipe, pipe.model

pipelines = LazyImport("diffusers.pipelines")
if isinstance(pretrained_model_name_or_path, str):
if torch_dtype == "auto":
Expand Down Expand Up @@ -729,6 +769,25 @@ def model_save_pretrained(model, save_directory, **kwargs):

# non-meta model uses model.save_pretrained for model and config saving
setattr(model, "save_pretrained", partial(model_save_pretrained, model))

if pipe.dtype != model.dtype:
pipe.to(model.dtype)
if pipe.device != model.device:
pipe.to(model.device)

if (
hasattr(model, "hf_device_map")
and len(model.hf_device_map) > 0
and pipe.device != model.device
and torch.device(model.device).type in ["cuda", "xpu"]
):
logger.error(
"Diffusion model is activated sequential model offloading, it will crash during moving to GPU/XPU. "
"Please use model path for quantization or "
"move the pipeline object to GPU/XPU before passing them into API."
)
exit(-1)

return pipe, model.to(device)


Expand Down Expand Up @@ -794,6 +853,21 @@ def is_gguf_model(model_path: Union[str, torch.nn.Module]) -> bool:
def is_diffusion_model(model_or_path: Union[str, object]) -> bool:
from auto_round.utils.common import LazyImport

# First check if it's a known diffusion pipeline by config/model_type to avoid unnecessary imports and file checks for non-diffusion models, which can be time-consuming.
try:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
model_type = getattr(config, "model_type", "")
# A special case for NextStep
if model_type == "nextstep":
return True
except:
logger.warning(
f"Failed to load config for {model_or_path}, trying to check model_index.json for diffusion pipeline."
)

# Then check if model_index.json exists for diffusion pipeline, which is a strong signal of being a diffusion pipeline.
if isinstance(model_or_path, str):
index_file = None
if not os.path.isdir(model_or_path):
Expand Down
70 changes: 56 additions & 14 deletions auto_round_extension/cuda/gptqmodel_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def get_marlin_layer(): ##use an ugly wrapper to import gptqmodel on demand
NEW_VERSION = False
if Version(gptqmodel.__version__) >= Version("5.0.0"):
NEW_VERSION = True
NEW_VERSION_6_0 = False
if Version(gptqmodel.__version__) >= Version("6.0.0"):
NEW_VERSION_6_0 = True
from gptqmodel.models._const import DEVICE, PLATFORM # pylint: disable=E0401
from gptqmodel.nn_modules.qlinear import BaseQuantLinear # pylint: disable=E0401
from gptqmodel.utils.backend import BACKEND # pylint: disable=E0401
Expand Down Expand Up @@ -244,20 +247,59 @@ def __init__(
# (since we have only one group per output channel)
desc_act = False

super().__init__(
bits=bits,
group_size=group_size,
sym=sym,
desc_act=desc_act,
in_features=in_features,
out_features=out_features,
bias=bias,
pack_dtype=pack_dtype,
backend=kwargs.pop("backend", BACKEND.MARLIN),
adapter=None,
register_buffers=False,
**kwargs,
)
backend = kwargs.pop("backend", BACKEND.MARLIN)
if NEW_VERSION_6_0:
# gptqmodel >= 6.0.0: BaseQuantLinear no longer accepts group_size/sym/desc_act/pack_dtype
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you help fix other gptqmodel backend issues, such as gptqmodel:exllamav2,gptqmodel:awq,etc.

Since the GPTQModel API is changing frequently, should we consider other repos, such as using vLLM directly?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only this backend is using BaseQuantLinear. Other backend should work as expected.
The vLLM kernel looks promising, but installing vLLM involves a large number of dependencies, which could result in a poor user experience.
vllm/csrc/quantization

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we bypass the GPTQModel API and use our own interface to directly leverage its kernels? I assume the kernel implementations change far less frequently.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#837
Is this one better?

# directly; they must be passed via validate_kwargs. Attributes are also set manually.
super().__init__(
bits=bits,
in_features=in_features,
out_features=out_features,
bias=bias,
backend=backend,
adapter=None,
register_buffers=False,
validate_kwargs={
"group_size": group_size,
"desc_act": desc_act,
"sym": sym,
"pack_dtype": pack_dtype,
},
**kwargs,
)
# Set attributes that intermediate classes (PackedQuantLinear /
# GPTQQuantLinear) would have set in the old API.
self.pack_dtype = pack_dtype
if pack_dtype == torch.int8:
self.pack_dtype_bits = 8
elif pack_dtype == torch.int16:
self.pack_dtype_bits = 16
elif pack_dtype == torch.int32:
self.pack_dtype_bits = 32
elif pack_dtype == torch.int64:
self.pack_dtype_bits = 64
else:
raise ValueError(f"Unsupported pack_dtype: {pack_dtype}")
self.pack_factor = self.pack_dtype_bits // bits
self.group_size = group_size if group_size != -1 else in_features
self.requested_group_size = group_size
self.desc_act = desc_act
self.sym = sym
else:
super().__init__(
bits=bits,
group_size=group_size,
sym=sym,
desc_act=desc_act,
in_features=in_features,
out_features=out_features,
bias=bias,
pack_dtype=pack_dtype,
backend=backend,
adapter=None,
register_buffers=False,
**kwargs,
)

# toggle fp32 mode depending on MARLIN or MARLIN_FP16 backend
self.fp32 = True if self.backend in [BACKEND.MARLIN, BACKEND.AUTO] else False
Expand Down
Loading