@@ -320,7 +320,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
320320 repeat for model.layer.1 and so on
321321 """
322322 resolved_mappings : list [ResolvedMapping ] = []
323-
323+
324324 module_to_name = {}
325325 for name , module in model .named_modules ():
326326 if module in module_to_name :
@@ -331,14 +331,11 @@ def _set_resolved_mappings(self, model: Module) -> None:
331331 )
332332 module_to_name [module ] = name
333333
334-
335-
336334 for mapping in self .mappings :
337-
338335 target_patterns = (mapping .smooth_layer , * mapping .balance_layers )
339336
340- for smooth_layer , * balance_layers in (
341- match_modules_set ( model , target_patterns , self .ignore )
337+ for smooth_layer , * balance_layers in match_modules_set (
338+ model , target_patterns , self .ignore
342339 ):
343340 smooth_name = module_to_name .get (smooth_layer )
344341 balance_names = [
@@ -353,10 +350,11 @@ def _set_resolved_mappings(self, model: Module) -> None:
353350 # skip mapping if any of the balance layers are incompatible
354351 if not all_compatible or len (balance_layers ) == 0 :
355352 logger .info (
356- f"skipping AWQ for { smooth_name } for mapping { mapping } " + (
357- " because found incompatible balance layers"
358- if not all_compatible else
359- f" because no balance layers were found"
353+ f"skipping AWQ for { smooth_name } for mapping { mapping } "
354+ + (
355+ " because found incompatible balance layers"
356+ if not all_compatible
357+ else " because no balance layers were found"
360358 )
361359 )
362360
@@ -812,7 +810,7 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod
812810 Implementation is a small alteration of os.path.commonprefix
813811 https://docs.python.org/3/library/os.path.html#os.path.commonprefix
814812 """
815- # adding "." before and after allows for handling a lot of corner
813+ # adding "." before and after allows for handling a lot of corner
816814 # cases which were previously mishandled ([case]->prefix->result)
817815 # case 0: single module: [.abc.] -> .abc. -> abc
818816 # case 1: substring modules: [.abc., .ab.] -> .ab -> ""
@@ -829,9 +827,9 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod
829827
830828 # 2) throw away module name fragment and leading dot
831829 # ".keep.thro" -> "keep"
832- parent_name = parent_name [1 : parent_name .rfind ("." )]
830+ parent_name = parent_name [1 : parent_name .rfind ("." )]
833831
834- # 3) return first parent that is not a module list
832+ # 3) return first common module that is not a module list
835833 while True :
836834 if parent_name == "" :
837835 return "" , module
0 commit comments