Skip to content

Commit 69cf95b

Browse files
committed
[AWQ] small refactor to use match_modules_set
Summary: modified _set_resolved_mappings to get smoothing and balance layers at same time. Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent db0b68d commit 69cf95b

File tree

1 file changed

+52
-41
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+52
-41
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from compressed_tensors.utils import (
88
align_modules,
99
get_execution_device,
10+
match_modules_set,
1011
match_named_modules,
1112
update_offload_parameter,
1213
)
@@ -312,68 +313,78 @@ def _set_resolved_mappings(self, model: Module) -> None:
312313
into ResolvedMapping objects, resolving regular expressions.
313314
Result is stored in _resolved_mappings.
314315
315-
For each activation in the mapping list, we find the corresponding weight to
316-
balance by searching for the longest substring. For instance, if our balance
317-
weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
318-
would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
319-
repeat for model.layer.1 and so on
316+
Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers)
317+
that belong together in the model architecture.
320318
"""
319+
# Build a module-to-name mapping for efficient lookups
320+
module_to_name = {module: name for name, module in model.named_modules()}
321+
321322
resolved_mappings: list[ResolvedMapping] = []
322323
for mapping_idx, mapping in enumerate(self.mappings):
323324
num_skipped_mappings = 0
324325

325-
for smooth_name, smooth_layer in (
326+
# Use match_modules_set to find coherent sets of modules
327+
target_patterns = (mapping.smooth_layer, *mapping.balance_layers)
328+
329+
for modules_set in (
326330
pbar := tqdm(
327-
match_named_modules(model, [mapping.smooth_layer], self.ignore)
331+
match_modules_set(model, target_patterns, self.ignore)
328332
)
329333
):
330334
pbar.set_description(
331335
f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
332336
f" ({num_skipped_mappings} skipped)"
333337
)
334338

335-
smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
336-
smooth_parent = get_layer_by_name(smooth_parent_name, model)
339+
# Unpack the matched set: first is smooth_layer, rest are balance_layers
340+
smooth_layer = modules_set[0]
341+
all_balance_layers = list(modules_set[1:])
337342

338-
balance_layers, balance_names = [], []
339-
for balance_regex in mapping.balance_layers:
340-
# find the submodules that match the activation layer
341-
for balance_suffix, balance_layer in match_named_modules(
342-
smooth_parent, [balance_regex], self.ignore
343-
):
344-
balance_name = f"{smooth_parent_name}.{balance_suffix}"
345-
346-
# exclude v_proj->o_proj mappings whose shapes are incompatible
347-
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
348-
if (
349-
isinstance(smooth_layer, torch.nn.Linear)
350-
and isinstance(balance_layer, torch.nn.Linear)
351-
and balance_name.endswith(".o_proj")
352-
and (
353-
(
354-
smooth_name.endswith(".v_proj")
355-
and smooth_layer.out_features
356-
!= balance_layer.in_features
357-
)
358-
or (
359-
smooth_name.endswith(".qkv_proj")
360-
and smooth_layer.out_features
361-
!= 3 * balance_layer.in_features
362-
)
343+
# Get names using the pre-built mapping
344+
smooth_name = module_to_name.get(smooth_layer)
345+
if smooth_name is None:
346+
continue
347+
348+
# Filter balance layers, skipping incompatible ones
349+
balance_layers = []
350+
balance_names = []
351+
352+
for balance_layer in all_balance_layers:
353+
balance_name = module_to_name.get(balance_layer)
354+
if balance_name is None:
355+
continue
356+
357+
# exclude v_proj->o_proj mappings whose shapes are incompatible
358+
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
359+
if (
360+
isinstance(smooth_layer, torch.nn.Linear)
361+
and isinstance(balance_layer, torch.nn.Linear)
362+
and balance_name.endswith(".o_proj")
363+
and (
364+
(
365+
smooth_name.endswith(".v_proj")
366+
and smooth_layer.out_features
367+
!= balance_layer.in_features
368+
)
369+
or (
370+
smooth_name.endswith(".qkv_proj")
371+
and smooth_layer.out_features
372+
!= 3 * balance_layer.in_features
363373
)
364-
):
365-
num_skipped_mappings += 1
366-
continue
374+
)
375+
):
376+
num_skipped_mappings += 1
377+
continue
367378

368-
balance_layers.append(balance_layer)
369-
balance_names.append(balance_name)
379+
balance_layers.append(balance_layer)
380+
balance_names.append(balance_name)
370381

371382
if len(balance_layers) == 0:
372383
continue
373384

374-
elif len(balance_layers) == 1:
385+
if len(balance_layers) == 1:
375386
# for single balance layer, parent is the balance layer
376-
parent_name, parent = balance_name, balance_layer
387+
parent_name, parent = balance_names[0], balance_layers[0]
377388
else:
378389
# for multiple balance layers, find lowest common parent
379390
parent_name, parent = get_lowest_common_parent(balance_names, model)

0 commit comments

Comments
 (0)