179 changes: 92 additions & 87 deletions src/llmcompressor/modifiers/awq/base.py
@@ -7,12 +7,14 @@
from compressed_tensors.utils import (
align_modules,
get_execution_device,
match_modules_set,
match_named_modules,
update_offload_parameter,
)
from loguru import logger
from pydantic import ConfigDict, PrivateAttr, model_validator
from torch.nn import Module
from torch.utils._pytree import tree_leaves
from tqdm import tqdm

from llmcompressor.core import Event, EventType, State
@@ -28,7 +30,10 @@
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.helpers import calibration_forward_context
from llmcompressor.utils.pytorch.module import get_layer_by_name
from llmcompressor.utils.pytorch.module import (
get_layer_by_name,
get_module_to_name_dict,
)

__all__ = ["AWQModifier"]

@@ -319,73 +324,57 @@ def _set_resolved_mappings(self, model: Module) -> None:
repeat for model.layer.1 and so on
"""
resolved_mappings: list[ResolvedMapping] = []
for mapping_idx, mapping in enumerate(self.mappings):
num_skipped_mappings = 0

for smooth_name, smooth_layer in (
pbar := tqdm(
match_named_modules(model, [mapping.smooth_layer], self.ignore)
)
module_to_name = get_module_to_name_dict(model)
for mapping in self.mappings:
for smooth_layers, *nested_balance_layers in match_modules_set(
model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
):
pbar.set_description(
f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
f" ({num_skipped_mappings} skipped)"
if len(smooth_layers) > 1:
raise ValueError(
"AWQ needs to match a single smoothlayer for each mapping but "
f"got {[module_to_name.get(s) for s in smooth_layers]}"
f" for mapping: {mapping}"
)
smooth_layer = smooth_layers[0]
smooth_name = module_to_name.get(smooth_layer)

# flatten per-pattern match lists in order (see the sketch after this function):
# [[b00, b01, ...], [b10, b11, ...], ...] -> [b00, b01, ..., b10, b11, ...]
balance_layers = tree_leaves(nested_balance_layers)
balance_names = [
module_to_name.get(balance_layer)
for balance_layer in balance_layers
]

all_compatible = _check_layers_are_compatible(
smooth_layer, smooth_name, balance_layers, balance_names
)

smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
smooth_parent = get_layer_by_name(smooth_parent_name, model)

balance_layers, balance_names = [], []
for balance_regex in mapping.balance_layers:
# find the submodules that match the activation layer
for balance_suffix, balance_layer in match_named_modules(
smooth_parent, [balance_regex], self.ignore
):
balance_name = f"{smooth_parent_name}.{balance_suffix}"

# exclude v_proj->o_proj mappings whose shapes are incompatible
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
if (
isinstance(smooth_layer, torch.nn.Linear)
and isinstance(balance_layer, torch.nn.Linear)
and balance_name.endswith(".o_proj")
and (
(
smooth_name.endswith(".v_proj")
and smooth_layer.out_features
!= balance_layer.in_features
)
or (
smooth_name.endswith(".qkv_proj")
and smooth_layer.out_features
!= 3 * balance_layer.in_features
)
)
):
num_skipped_mappings += 1
continue

balance_layers.append(balance_layer)
balance_names.append(balance_name)
# skip mapping if balance layers are incompatible or none were found
if not all_compatible or len(balance_layers) == 0:
    logger.warning(
        f"Skipping AWQ for {smooth_name} in mapping {mapping}"
        + (
            " because incompatible balance layers were found"
            if not all_compatible
            else " because no balance layers were found"
        )
    )

if len(balance_layers) == 0:
continue

elif len(balance_layers) == 1:
# for single balance layer, parent is the balance layer
parent_name, parent = balance_name, balance_layer
else:
# for multiple balance layers, find lowest common parent
parent_name, parent = get_lowest_common_parent(balance_names, model)
ancestor_name, ancestor = get_lowest_ancestor_with_avoid(
balance_names, model, torch.nn.ModuleList
)

resolved_mappings.append(
ResolvedMapping(
smooth_name,
smooth_layer,
balance_layers,
balance_names=balance_names,
parent=parent,
parent_name=parent_name,
parent=ancestor,
parent_name=ancestor_name,
)
)
self._resolved_mappings = resolved_mappings
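For reference, here is a minimal sketch of the flattening step above, using placeholder strings in place of real modules. The grouping shape is an assumption based on how match_modules_set is unpacked in this function: one list of matches per pattern in each yielded set.

```python
from torch.utils._pytree import tree_leaves

# placeholder stand-ins for matched modules; for an MoE mapping such as
#   smooth:  "re:.*post_attention_layernorm"
#   balance: ["re:.*gate_proj", "re:.*up_proj"]
# a single yielded set is assumed to look like:
smooth_layers = ["post_attention_layernorm"]
nested_balance_layers = [
    ["experts.0.gate_proj", "experts.1.gate_proj"],  # matches for pattern 1
    ["experts.0.up_proj", "experts.1.up_proj"],      # matches for pattern 2
]

# tree_leaves flattens the nested per-pattern lists, preserving order
assert tree_leaves(nested_balance_layers) == [
    "experts.0.gate_proj",
    "experts.1.gate_proj",
    "experts.0.up_proj",
    "experts.1.up_proj",
]
```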
@@ -721,6 +710,54 @@ def _assert_all_activations_consumed(self):
raise RuntimeError("Some cached activations were not used")


def _check_layers_are_compatible(
    smooth_layer: Module,
    smooth_name: str,
    balance_layers: list[Module],
    balance_names: list[str],
) -> bool:
    """
    Returns True if the smooth layer is shape-compatible with every balance
    layer, False if any smooth/balance pair is incompatible.
    """
for balance_layer, balance_name in zip(balance_layers, balance_names):
# exclude v_proj->o_proj mappings whose shapes are incompatible
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
if (
isinstance(smooth_layer, torch.nn.Linear)
and isinstance(balance_layer, torch.nn.Linear)
and balance_name.endswith(".o_proj")
and (
(
smooth_name.endswith(".v_proj")
and smooth_layer.out_features != balance_layer.in_features
)
or (
smooth_name.endswith(".qkv_proj")
and smooth_layer.out_features != 3 * balance_layer.in_features
)
)
):
return False
return True
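As a concrete illustration of the excluded v_proj->o_proj case: under grouped-query attention, v_proj's out_features is num_kv_heads * head_dim while o_proj's in_features is num_heads * head_dim, so the pair is rejected. A sketch with hypothetical dimensions and module names:

```python
import torch

# hypothetical GQA shapes: 32 query heads, 8 KV heads, head_dim = 128
v_proj = torch.nn.Linear(4096, 8 * 128)   # out_features = 1024
o_proj = torch.nn.Linear(32 * 128, 4096)  # in_features = 4096

# out_features (1024) != in_features (4096), so this pair is incompatible
assert not _check_layers_are_compatible(
    v_proj,
    "model.layers.0.self_attn.v_proj",
    [o_proj],
    ["model.layers.0.self_attn.o_proj"],
)
```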


Review comment (Collaborator): The default for avoid seems risky to me. Maybe you meant it to be torch.nn.ModuleList? Otherwise this will probably just return the root model on every PyTorch model unless avoid is explicitly set.

def get_lowest_ancestor_with_avoid(
    names: list[str], model: Module, avoid=torch.nn.Module
) -> tuple[str, Module]:
    """
    Given a list of module names, return the lowest common ancestor that is
    not an instance of the avoided class/type, as a (name, module) tuple.

    NOTE: primarily used to exclude ancestors of type ModuleList, which don't
    play nicely with hooks because their forward method is never called
    directly in MoE models. See Qwen3MoeSparseMoeBlock for example: experts
    are selected based on router output and their forward methods are called
    individually.
    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233

    The common-prefix search is a small alteration of os.path.commonprefix
    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
    """
    # find the longest common dotted prefix among all names
    s1, s2 = min(names), max(names)
    name = s1
    for i, c in enumerate(s1):
        if c != s2[i]:
            name = s1[:i].rstrip(".")
            break

    # walk up the module tree until the ancestor is not an `avoid` instance
    while True:
        if name == "":
            return "", model
        ancestor = get_layer_by_name(name, model)
        if not isinstance(ancestor, avoid):
            return name, ancestor
        name = ".".join(name.split(".")[:-1])


def _pseudo_quantize_tensor(
w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
):
Expand Down Expand Up @@ -779,35 +816,3 @@ def _accumulate_mean(
new_count = prev_count + num_added

return (prev_sum + sum_added) / new_count, new_count


def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
"""
Given a list of names, returns the lowest-scope common parent.

NOTE: function excludes parents of type ModuleList, which don't play
nicely with hooks because their forward method is never directly
called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
are selected based on router output and their forward method is called.
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233

Returns name of parent and pointer to parent module

Implementation is a small alteration of os.path.commonprefix
https://docs.python.org/3/library/os.path.html#os.path.commonprefix
"""
s1 = min(names)
s2 = max(names)
parent_name = ""
for i, c in enumerate(s1):
if c != s2[i]:
parent_name = s1[:i].rstrip(".")
break

while True:
if parent_name == "":
return "", module
parent = get_layer_by_name(parent_name, module)
if not isinstance(parent, torch.nn.ModuleList):
return parent_name, parent
parent_name = ".".join(parent_name.split(".")[:-1])
43 changes: 20 additions & 23 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors.utils import align_module_device, match_named_modules
from compressed_tensors.utils import align_module_device, match_modules_set
from loguru import logger
from pydantic import ConfigDict, Field
from torch.nn import Module
from torch.utils._pytree import tree_leaves

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
@@ -14,7 +15,7 @@
handle_mapping_resolution_errors,
)
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.pytorch.module import get_layer_by_name
from llmcompressor.utils.pytorch.module import get_module_to_name_dict

MINIMUM_SMOOTHING_SCALE = 1e-5

@@ -95,7 +96,7 @@ class SmoothQuantModifier(Modifier):
"""

smoothing_strength: float = 0.5
mappings: Optional[List[Union[Tuple, List]]] = None
mappings: Optional[List[Tuple[List[str], str]]] = None
ignore: Optional[List[str]] = None
num_calibration_steps: Optional[int] = None
calibration_function: Optional[Callable] = None
@@ -198,27 +199,23 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
be balanced.
"""
resolved_mappings = []
for to_balance, to_smooth in self.mappings:
to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth

for smooth_name, smooth_layer in match_named_modules(
model, to_smooth_list, self.ignore
module_to_name = get_module_to_name_dict(model)
for mapping in self.mappings:
for *nested_balance_layers, smooth_layers in match_modules_set(
model, tree_leaves(mapping), self.ignore
):
# Search for balance layers within the parent scope
smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
smooth_parent = get_layer_by_name(smooth_parent_name, model)

balance_layers = [
balance_layer
for _, balance_layer in match_named_modules(
smooth_parent, to_balance, self.ignore
)
]

if balance_layers:
resolved_mappings.append(
SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
if len(smooth_layers) > 1:
raise ValueError(
"SmoothQuant must match a single smooth layer for each mapping"
f" but got {[module_to_name.get(s) for s in smooth_layers]}"
f" for mapping: {mapping}"
)
smooth_layer = smooth_layers[0]
smooth_name = module_to_name.get(smooth_layer)
balance_layers = tree_leaves(nested_balance_layers)
resolved_mappings.append(
SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
)

return resolved_mappings
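With the narrowed mapping type, each entry is a ([balance patterns], smooth pattern) pair; tree_leaves flattens it into the pattern sequence handed to match_modules_set with the smooth pattern last, which is what the `*nested_balance_layers, smooth_layers` unpacking above relies on. A sketch in the style of the shipped defaults (regex patterns are illustrative):

```python
from torch.utils._pytree import tree_leaves

mapping = (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm")

# the smooth pattern lands last in the flattened sequence
assert tree_leaves(mapping) == [
    "re:.*q_proj",
    "re:.*k_proj",
    "re:.*v_proj",
    "re:.*input_layernorm",
]
```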

4 changes: 2 additions & 2 deletions src/llmcompressor/modifiers/smoothquant/utils.py
@@ -1,6 +1,6 @@
import functools
from collections import namedtuple
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple

from loguru import logger

Expand All @@ -10,7 +10,7 @@
"DEFAULT_SMOOTHQUANT_MAPPINGS",
]

LayerMapType = Tuple[Union[List[str], str], Union[List[str], str]]
LayerMapType = Tuple[List[str], str]
LayerMap: LayerMapType = namedtuple("LayerMap", ["balance_layers", "smooth_layers"])

DEFAULT_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [
5 changes: 4 additions & 1 deletion src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -11,6 +11,7 @@
)
from compressed_tensors.utils import TorchDtype, get_head_dim
from pydantic import Field, ValidationInfo, field_validator
from torch.utils._pytree import tree_leaves
from transformers import PreTrainedModel

from llmcompressor.core import Event, EventType, State
@@ -205,7 +206,9 @@ def _fuse_norms(self, model: PreTrainedModel):
for norm, *linears in match_modules_set(
model, (mapping.norm, *mapping.linears)
):
fuse_norm_linears(norm, linears)
# match_modules_set yields one list of matched modules per pattern
assert len(norm) == 1, "expected exactly one norm per matched set"
fuse_norm_linears(norm[0], tree_leaves(linears))

def _create_r1_scheme(self) -> TransformScheme:
return TransformScheme(
13 changes: 13 additions & 0 deletions src/llmcompressor/utils/pytorch/module.py
@@ -10,6 +10,7 @@
import torch
from compressed_tensors import InternalModule
from compressed_tensors.quantization.utils import is_module_quantized
from loguru import logger
from torch.nn import Linear, Module, Parameter
from torch.nn.modules.conv import _ConvNd
from transformers import PreTrainedModel
@@ -369,3 +370,15 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module:
if not layer_name:
return module
return attrgetter(layer_name)(module)


def get_module_to_name_dict(model: Module) -> dict[Module, str]:
    """
    Build a reverse mapping from each module instance to its name in the model.
    """
    module_to_name = {}
    for name, module in model.named_modules():
        if module in module_to_name:
            logger.warning(
                f"{name} and {module_to_name[module]} both share the same "
                "module, which can result in unexpected behavior"
            )
        module_to_name[module] = name
    return module_to_name
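A minimal usage sketch; note that named_modules() names the root module with the empty string:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
module_to_name = get_module_to_name_dict(model)

assert module_to_name[model] == ""      # root module
assert module_to_name[model[0]] == "0"  # the Linear submodule
```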