@@ -313,74 +313,43 @@ def _set_resolved_mappings(self, model: Module) -> None:
313313 into ResolvedMapping objects, resolving regular expressions.
314314 Result is stored in _resolved_mappings.
315315
316- Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers)
317- that belong together in the model architecture.
316+ For each activation in the mapping list, we find the corresponding weight to
317+ balance by searching for the longest substring. For instance, if our balance
318+ weight is "re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
319+ would match model.layer.0.q_proj to model.layer.0.self_attn_layer_norm and
320+ repeat for model.layer.1 and so on
318321 """
319- # Build a module-to-name mapping for efficient lookups
320- module_to_name = {module : name for name , module in model .named_modules ()}
321-
322322 resolved_mappings : list [ResolvedMapping ] = []
323+ module_to_name = {module : name for name , module in model .named_modules ()}
323324 for mapping_idx , mapping in enumerate (self .mappings ):
324325 num_skipped_mappings = 0
325326
326327 # Use match_modules_set to find coherent sets of modules
327328 target_patterns = (mapping .smooth_layer , * mapping .balance_layers )
328329
329- for modules_set in (
330+ for smooth_layer , * balance_layers in (
330331 pbar := tqdm (match_modules_set (model , target_patterns , self .ignore ))
331332 ):
332333 pbar .set_description (
333334 f"Resolving mapping { mapping_idx + 1 } /{ len (self .mappings )} "
334335 f" ({ num_skipped_mappings } skipped)"
335336 )
336337
337- # Unpack the matched set: first is smooth_layer, rest are balance_layers
338- smooth_layer = modules_set [0 ]
339- all_balance_layers = list (modules_set [1 :])
340-
341- # Get names using the pre-built mapping
342338 smooth_name = module_to_name .get (smooth_layer )
343- if smooth_name is None :
344- continue
339+ balance_names = [
340+ module_to_name .get (balance_layer )
341+ for balance_layer in balance_layers
342+ ]
345343
346- # Filter balance layers, skipping incompatible ones
347- balance_layers = []
348- balance_names = []
349-
350- for balance_layer in all_balance_layers :
351- balance_name = module_to_name .get (balance_layer )
352- if balance_name is None :
353- continue
354-
355- # exclude v_proj->o_proj mappings whose shapes are incompatible
356- # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
357- if (
358- isinstance (smooth_layer , torch .nn .Linear )
359- and isinstance (balance_layer , torch .nn .Linear )
360- and balance_name .endswith (".o_proj" )
361- and (
362- (
363- smooth_name .endswith (".v_proj" )
364- and smooth_layer .out_features
365- != balance_layer .in_features
366- )
367- or (
368- smooth_name .endswith (".qkv_proj" )
369- and smooth_layer .out_features
370- != 3 * balance_layer .in_features
371- )
372- )
373- ):
374- num_skipped_mappings += 1
375- continue
376-
377- balance_layers .append (balance_layer )
378- balance_names .append (balance_name )
344+ all_compatible = _check_layers_are_compatible (
345+ smooth_layer , smooth_name , balance_layers , balance_names
346+ )
379347
380- if len (balance_layers ) == 0 :
348+ # skip mapping if any of the balance layers are incompatible
349+ if not all_compatible or len (balance_layers ) == 0 :
350+ num_skipped_mappings += 1
381351 continue
382-
383- if len (balance_layers ) == 1 :
352+ elif len (balance_layers ) == 1 :
384353 # for single balance layer, parent is the balance layer
385354 parent_name , parent = balance_names [0 ], balance_layers [0 ]
386355 else :
@@ -730,6 +699,35 @@ def _assert_all_activations_consumed(self):
730699 raise RuntimeError ("Some cached activations were not used" )
731700
732701
702+ def _check_layers_are_compatible (
703+ smooth_layer , smooth_name , balance_layers , balance_names
704+ ):
705+ """
706+ returns True if they are all compatible
707+ returns False if any smooth & balance layers are incompatible
708+ """
709+ for balance_layer , balance_name in zip (balance_layers , balance_names ):
710+ # exclude v_proj->o_proj mappings whose shapes are incompatible
711+ # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
712+ if (
713+ isinstance (smooth_layer , torch .nn .Linear )
714+ and isinstance (balance_layer , torch .nn .Linear )
715+ and balance_name .endswith (".o_proj" )
716+ and (
717+ (
718+ smooth_name .endswith (".v_proj" )
719+ and smooth_layer .out_features != balance_layer .in_features
720+ )
721+ or (
722+ smooth_name .endswith (".qkv_proj" )
723+ and smooth_layer .out_features != 3 * balance_layer .in_features
724+ )
725+ )
726+ ):
727+ return False
728+ return True
729+
730+
733731def _pseudo_quantize_tensor (
734732 w : torch .Tensor , symmetric : bool = False , bit_width : int = 8 , group_size : int = - 1
735733):
0 commit comments