1010 match_modules_set ,
1111 match_named_modules ,
1212 update_offload_parameter ,
13+ get_lowest_common_ancestor_name ,
1314)
1415from loguru import logger
1516from pydantic import ConfigDict , PrivateAttr , model_validator
1617from torch .nn import Module
1718from tqdm import tqdm
18-
19+ from torch . utils . _pytree import tree_flatten
1920from llmcompressor .core import Event , EventType , State
2021from llmcompressor .modifiers import Modifier
2122from llmcompressor .modifiers .awq .mappings import (
@@ -332,12 +333,20 @@ def _set_resolved_mappings(self, model: Module) -> None:
332333 module_to_name [module ] = name
333334
334335 for mapping in self .mappings :
335- target_patterns = (mapping .smooth_layer , * mapping .balance_layers )
336-
337- for smooth_layer , * balance_layers in match_modules_set (
338- model , target_patterns , self .ignore
336+ for smooth_layers , * nested_balance_layers in match_modules_set (
337+ model , (mapping .smooth_layer , * mapping .balance_layers ), self .ignore
339338 ):
339+ assert len (smooth_layers )== 1 , (
340+ "AWQ mappings need to match a single smoothlayer for each mapping but got "
341+ f"{ [module_to_name .get (smooth_layer ) for smooth_layer in smooth_layers ]} "
342+ f"when matching { mapping .smooth_layer } "
343+ )
344+ smooth_layer = smooth_layers [0 ]
340345 smooth_name = module_to_name .get (smooth_layer )
346+
347+ #[[b00, b01, b02...], [b10, b11, b12,...], ...] v
348+ # [b00, b01, b02, ..., b10, b11, b12, ...]
349+ balance_layers = tree_flatten (nested_balance_layers )[0 ]
341350 balance_names = [
342351 module_to_name .get (balance_layer )
343352 for balance_layer in balance_layers
@@ -361,16 +370,17 @@ def _set_resolved_mappings(self, model: Module) -> None:
361370 continue
362371 else :
363372 # for multiple balance layers, find lowest common parent
364- parent_name , parent = get_lowest_common_module (balance_names , model )
373+ ancestor_name = get_lowest_common_ancestor_name (balance_names )
374+ ancestor , ancestor_name = get_lowest_non_module_list_ancestor (ancestor_name , )
365375
366376 resolved_mappings .append (
367377 ResolvedMapping (
368378 smooth_name ,
369379 smooth_layer ,
370380 balance_layers ,
371381 balance_names = balance_names ,
372- parent = parent ,
373- parent_name = parent_name ,
382+ parent = ancestor ,
383+ parent_name = ancestor_name ,
374384 )
375385 )
376386 self ._resolved_mappings = resolved_mappings
@@ -795,45 +805,25 @@ def _accumulate_mean(
795805 return (prev_sum + sum_added ) / new_count , new_count
796806
797807
798- def get_lowest_common_module ( names : list [ str ] , module : Module ) -> tuple [str , Module ]:
808+ def get_lowest_non_module_list_ancestor ( name , module : Module ) -> tuple [str , Module ]:
799809 """
800- Given a list of names, returns the lowest-scope common module.
810+ Given a name: foo.bar.baz, finds lowest ancestor that's not a ModuleList
811+ i.e. module_list.module_dict.module_list -> module_list.module_dict
812+ i.e. module_list.module_dict -> module_list.module_dict
813+ (self is an ancestor of self)
801814
802- NOTE: function excludes modules of type ModuleList, which don't play
815+ NOTE: This is needed because ModuleLists don't play
803816 nicely with hooks because their forward method is never directly
804817 called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
805818 are selected based on router output and their forward method is called.
806819 https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
807820
808821 Returns name of module and pointer to module
809-
810- Implementation is a small alteration of os.path.commonprefix
811- https://docs.python.org/3/library/os.path.html#os.path.commonprefix
812822 """
813- # adding "." before and after allows for handling a lot of corner
814- # cases which were previously mishandled ([case]->prefix->result)
815- # case 0: single module: [.abc.] -> .abc. -> abc
816- # case 1: substring modules: [.abc., .ab.] -> .ab -> ""
817- # case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab
818- s1 = min (names ) + "."
819- s2 = max (names ) + "."
820-
821- # 1) find longest shared prefix
822- parent_name = "."
823- for i , c in enumerate (s1 ):
824- if c != s2 [i ]:
825- break
826- parent_name += c
827-
828- # 2) throw away module name fragment and leading dot
829- # ".keep.thro" -> "keep"
830- parent_name = parent_name [1 : parent_name .rfind ("." )]
831-
832- # 3) return first common module that is not a module list
833823 while True :
834- if parent_name == "" :
824+ if name == "" :
835825 return "" , module
836- parent = get_layer_by_name (parent_name , module )
837- if not isinstance (parent , torch .nn .ModuleList ):
838- return parent_name , parent
839- parent_name = "." .join (parent_name .split ("." )[:- 1 ])
826+ module = get_layer_by_name (name , module )
827+ if not isinstance (module , torch .nn .ModuleList ):
828+ return name , module
829+ name = "." .join (parent_name .split ("." )[:- 1 ])
0 commit comments