Skip to content

Commit af7d76e

Browse files
committed
fix infinite loop and the asserts
Summary Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent 056daee commit af7d76e

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from loguru import logger
1616
from pydantic import ConfigDict, PrivateAttr, model_validator
1717
from torch.nn import Module
18-
from torch.utils._pytree import tree_flatten
18+
from torch.utils._pytree import tree_leaves
1919
from tqdm import tqdm
2020

2121
from llmcompressor.core import Event, EventType, State
@@ -32,7 +32,7 @@
3232
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
3333
from llmcompressor.utils.helpers import calibration_forward_context
3434
from llmcompressor.utils.pytorch.module import (
35-
get_module_to_name_dict,
35+
get_module_to_name_dict, get_layer_by_name
3636
)
3737

3838
__all__ = ["AWQModifier"]
@@ -329,17 +329,18 @@ def _set_resolved_mappings(self, model: Module) -> None:
329329
for smooth_layers, *nested_balance_layers in match_modules_set(
330330
model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
331331
):
332-
assert len(smooth_layers) == 1, (
333-
"AWQ mappings need to match a single smoothlayer for each "
334-
f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}"
335-
f" for mapping: {mapping}"
336-
)
332+
if len(smooth_layers)>1:
333+
raise ValueError(
334+
"AWQ needs to match a single smoothlayer for each mapping but "
335+
f"got {[module_to_name.get(s) for s in smooth_layers]}"
336+
f" for mapping: {mapping}"
337+
)
337338
smooth_layer = smooth_layers[0]
338339
smooth_name = module_to_name.get(smooth_layer)
339340

340341
# [[b00, b01, b02...], [b10, b11, b12,...], ...] v
341342
# [b00, b01, b02, ..., b10, b11, b12, ...]
342-
balance_layers = tree_flatten(nested_balance_layers)[0]
343+
balance_layers = tree_leaves(nested_balance_layers)
343344
balance_names = [
344345
module_to_name.get(balance_layer)
345346
for balance_layer in balance_layers
@@ -351,7 +352,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
351352

352353
# skip mapping if any of the balance layers are incompatible
353354
if not all_compatible or len(balance_layers) == 0:
354-
logger.info(
355+
logger.warning(
355356
f"skipping AWQ for {smooth_name} for mapping {mapping}"
356357
+ (
357358
" because found incompatible balance layers"
@@ -362,13 +363,9 @@ def _set_resolved_mappings(self, model: Module) -> None:
362363

363364
continue
364365

365-
ancestor_name = get_lowest_common_ancestor_name(balance_names)
366-
# no ModuleList ancestors
367-
while not isinstance(
368-
(ancestor := model.get_submodule(ancestor_name)),
369-
torch.nn.ModuleList,
370-
):
371-
ancestor_name = ancestor_name.rsplit(".", 1)[0]
366+
ancestor_name, ancestor = get_lowest_ancestor_with_avoid(
367+
balance_names, model, torch.nn.ModuleList
368+
)
372369

373370
resolved_mappings.append(
374371
ResolvedMapping(
@@ -741,6 +738,23 @@ def _check_layers_are_compatible(
741738
return False
742739
return True
743740

741+
def get_lowest_ancestor_with_avoid(name: str, model: Module, avoid=torch.nn.Module):
742+
"""
743+
get lowest ancestor that is not the avoided class/type
744+
745+
NOTE: primarily used to exclude parents of type ModuleList, which don't play
746+
nicely with hooks because their forward method is never directly
747+
called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
748+
are selected based on router output and their forward method is called.
749+
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
750+
"""
751+
while True:
752+
if name == "":
753+
return "", model
754+
ancestor = get_layer_by_name(name, model)
755+
if not isinstance(ancestor, avoid):
756+
return name, ancestor
757+
name = ".".join(name.split(".")[:-1])
744758

745759
def _pseudo_quantize_tensor(
746760
w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,12 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
204204
for *nested_balance_layers, smooth_layers in match_modules_set(
205205
model, tree_leaves(mapping), self.ignore
206206
):
207-
assert len(smooth_layers) == 1, (
208-
"SmoothQuant mappings must match a single smooth layer for each "
209-
f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}"
210-
f" for mapping: {mapping}"
211-
)
207+
if len(smooth_layers)>1:
208+
raise ValueError (
209+
"SmoothQuant must match a single smooth layer for each mapping"
210+
f" but got {[module_to_name.get(s) for s in smooth_layers]}"
211+
f" for mapping: {mapping}"
212+
)
212213
smooth_layer = smooth_layers[0]
213214
smooth_name = module_to_name.get(smooth_layers[0])
214215
balance_layers = tree_leaves(nested_balance_layers)

0 commit comments

Comments
 (0)