Skip to content

Commit 351568d

Browse files
committed
fixing logic and test update
Summary Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent dc8e3ae commit 351568d

File tree

2 files changed

+61
-50
lines changed

2 files changed

+61
-50
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -313,74 +313,43 @@ def _set_resolved_mappings(self, model: Module) -> None:
313313
into ResolvedMapping objects, resolving regular expressions.
314314
Result is stored in _resolved_mappings.
315315
316-
Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers)
317-
that belong together in the model architecture.
316+
For each activation in the mapping list, we find the corresponding weight to
317+
balance by searching for the longest substring. For instance, if our balance
318+
weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
319+
would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
320+
repeat for model.layer.1 and so on
318321
"""
319-
# Build a module-to-name mapping for efficient lookups
320-
module_to_name = {module: name for name, module in model.named_modules()}
321-
322322
resolved_mappings: list[ResolvedMapping] = []
323+
module_to_name = {module: name for name, module in model.named_modules()}
323324
for mapping_idx, mapping in enumerate(self.mappings):
324325
num_skipped_mappings = 0
325326

326327
# Use match_modules_set to find coherent sets of modules
327328
target_patterns = (mapping.smooth_layer, *mapping.balance_layers)
328329

329-
for modules_set in (
330+
for smooth_layer, *balance_layers in (
330331
pbar := tqdm(match_modules_set(model, target_patterns, self.ignore))
331332
):
332333
pbar.set_description(
333334
f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
334335
f" ({num_skipped_mappings} skipped)"
335336
)
336337

337-
# Unpack the matched set: first is smooth_layer, rest are balance_layers
338-
smooth_layer = modules_set[0]
339-
all_balance_layers = list(modules_set[1:])
340-
341-
# Get names using the pre-built mapping
342338
smooth_name = module_to_name.get(smooth_layer)
343-
if smooth_name is None:
344-
continue
339+
balance_names = [
340+
module_to_name.get(balance_layer)
341+
for balance_layer in balance_layers
342+
]
345343

346-
# Filter balance layers, skipping incompatible ones
347-
balance_layers = []
348-
balance_names = []
349-
350-
for balance_layer in all_balance_layers:
351-
balance_name = module_to_name.get(balance_layer)
352-
if balance_name is None:
353-
continue
354-
355-
# exclude v_proj->o_proj mappings whose shapes are incompatible
356-
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
357-
if (
358-
isinstance(smooth_layer, torch.nn.Linear)
359-
and isinstance(balance_layer, torch.nn.Linear)
360-
and balance_name.endswith(".o_proj")
361-
and (
362-
(
363-
smooth_name.endswith(".v_proj")
364-
and smooth_layer.out_features
365-
!= balance_layer.in_features
366-
)
367-
or (
368-
smooth_name.endswith(".qkv_proj")
369-
and smooth_layer.out_features
370-
!= 3 * balance_layer.in_features
371-
)
372-
)
373-
):
374-
num_skipped_mappings += 1
375-
continue
376-
377-
balance_layers.append(balance_layer)
378-
balance_names.append(balance_name)
344+
all_compatible = _check_layers_are_compatible(
345+
smooth_layer, smooth_name, balance_layers, balance_names
346+
)
379347

380-
if len(balance_layers) == 0:
348+
# skip mapping if any of the balance layers are incompatible
349+
if not all_compatible or len(balance_layers) == 0:
350+
num_skipped_mappings += 1
381351
continue
382-
383-
if len(balance_layers) == 1:
352+
elif len(balance_layers) == 1:
384353
# for single balance layer, parent is the balance layer
385354
parent_name, parent = balance_names[0], balance_layers[0]
386355
else:
@@ -730,6 +699,35 @@ def _assert_all_activations_consumed(self):
730699
raise RuntimeError("Some cached activations were not used")
731700

732701

702+
def _check_layers_are_compatible(
703+
smooth_layer, smooth_name, balance_layers, balance_names
704+
):
705+
"""
706+
returns True if they are all compatible
707+
returns False if any smooth & balance layers are incompatible
708+
"""
709+
for balance_layer, balance_name in zip(balance_layers, balance_names):
710+
# exclude v_proj->o_proj mappings whose shapes are incompatible
711+
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
712+
if (
713+
isinstance(smooth_layer, torch.nn.Linear)
714+
and isinstance(balance_layer, torch.nn.Linear)
715+
and balance_name.endswith(".o_proj")
716+
and (
717+
(
718+
smooth_name.endswith(".v_proj")
719+
and smooth_layer.out_features != balance_layer.in_features
720+
)
721+
or (
722+
smooth_name.endswith(".qkv_proj")
723+
and smooth_layer.out_features != 3 * balance_layer.in_features
724+
)
725+
)
726+
):
727+
return False
728+
return True
729+
730+
733731
def _pseudo_quantize_tensor(
734732
w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
735733
):

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,12 @@ def test_set_resolved_mappings():
8585
assert set(mapping.balance_names) == {"decoder.mlp.down_proj"}
8686
assert mapping.parent_name == "decoder.mlp.down_proj"
8787

88-
# make sure we exclude case where o_proj/v_proj shapes are mismatched
8988
awq = AWQModifier(
9089
mappings=[
90+
# make sure we exclude case where o_proj/v_proj shapes are mismatched
9191
AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
92+
# make sure we exclude mapping if any balance layers are skipped
93+
AWQMapping("re:.*v_proj", ["re:.*z_proj", "re:.*o_proj"]),
9294
],
9395
scheme="W4A16_ASYM",
9496
)
@@ -101,6 +103,7 @@ def test_set_resolved_mappings():
101103
"q_proj": torch.nn.Linear(4, 2),
102104
"k_proj": torch.nn.Linear(4, 2),
103105
"v_proj": torch.nn.Linear(4, 2),
106+
"z_proj": torch.nn.Linear(2, 4),
104107
"o_proj": torch.nn.Linear(4, 4),
105108
}
106109
)
@@ -109,6 +112,16 @@ def test_set_resolved_mappings():
109112
}
110113
)
111114
awq._set_resolved_mappings(model)
115+
if len(awq._resolved_mappings) > 0:
116+
assert all(
117+
"o_proj" not in name for name in awq._resolved_mappings[0].balance_names
118+
), "should have skipped v->o mapping because o is incompatible"
119+
assert all(
120+
"z_proj" not in name for name in awq._resolved_mappings[0].balance_names
121+
), (
122+
"should have skipped v->[z,o] mapping because o is incompatible even though"
123+
"z is compatible"
124+
)
112125
assert len(awq._resolved_mappings) == 0
113126

114127

0 commit comments

Comments
 (0)