 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
-from torch.utils._pytree import tree_flatten
+from torch.utils._pytree import tree_leaves
 from tqdm import tqdm
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
-    get_module_to_name_dict,
+    get_layer_by_name,
+    get_module_to_name_dict,
 )
 
 __all__ = ["AWQModifier"]
@@ -329,17 +329,18 @@ def _set_resolved_mappings(self, model: Module) -> None:
             for smooth_layers, *nested_balance_layers in match_modules_set(
                 model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
             ):
-                assert len(smooth_layers) == 1, (
-                    "AWQ mappings need to match a single smoothlayer for each "
-                    f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}"
-                    f" for mapping: {mapping}"
-                )
+                if len(smooth_layers) > 1:
+                    raise ValueError(
+                        "AWQ needs to match a single smooth layer for each mapping but "
+                        f"got {[module_to_name.get(s) for s in smooth_layers]}"
+                        f" for mapping: {mapping}"
+                    )
                 smooth_layer = smooth_layers[0]
                 smooth_name = module_to_name.get(smooth_layer)
 
                 # [[b00, b01, b02, ...], [b10, b11, b12, ...], ...]
                 # -> [b00, b01, b02, ..., b10, b11, b12, ...]
-                balance_layers = tree_flatten(nested_balance_layers)[0]
+                balance_layers = tree_leaves(nested_balance_layers)
                 balance_names = [
                     module_to_name.get(balance_layer)
                     for balance_layer in balance_layers
@@ -351,7 +352,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
 
                 # skip mapping if any of the balance layers are incompatible
                 if not all_compatible or len(balance_layers) == 0:
-                    logger.info(
+                    logger.warning(
                         f"skipping AWQ for {smooth_name} for mapping {mapping}"
                         + (
                             " because found incompatible balance layers"
@@ -362,13 +363,9 @@ def _set_resolved_mappings(self, model: Module) -> None:
 
                     continue
 
-                ancestor_name = get_lowest_common_ancestor_name(balance_names)
-                # no ModuleList ancestors
-                while not isinstance(
-                    (ancestor := model.get_submodule(ancestor_name)),
-                    torch.nn.ModuleList,
-                ):
-                    ancestor_name = ancestor_name.rsplit(".", 1)[0]
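+                # hook onto the nearest non-ModuleList ancestor: ModuleList defines
+                # no forward of its own, so hooks registered on it would never fire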
+                ancestor_name, ancestor = get_lowest_ancestor_with_avoid(
+                    balance_names, model, torch.nn.ModuleList
+                )
 
                 resolved_mappings.append(
                     ResolvedMapping(
@@ -741,6 +738,23 @@ def _check_layers_are_compatible(
         return False
     return True
 
+def get_lowest_ancestor_with_avoid(
+    names: List[str], model: Module, avoid=torch.nn.ModuleList
+):
+    """
+    Given a list of module names, return the (name, module) pair of their lowest
+    common ancestor that is not an instance of the avoided class/type.
+
+    NOTE: primarily used to exclude parents of type ModuleList, which don't play
+    nicely with hooks because their forward method is never directly called for
+    MoE models. See Qwen3MoeSparseMoeBlock, for example: experts are selected
+    based on router output and their forward methods are called directly, so the
+    ModuleList's forward is never invoked.
+    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
+    """
+    name = get_lowest_common_ancestor_name(names)
+    while True:
+        if name == "":
+            return "", model
+        ancestor = get_layer_by_name(name, model)
+        if not isinstance(ancestor, avoid):
+            return name, ancestor
+        name = ".".join(name.split(".")[:-1])
 
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1