Skip to content

Commit 0f265f9

Browse files
committed
fixes
Summary Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent d6dc289 commit 0f265f9

File tree

3 files changed

+62
-99
lines changed

3 files changed

+62
-99
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
match_modules_set,
1111
match_named_modules,
1212
update_offload_parameter,
13+
get_lowest_common_ancestor_name,
1314
)
1415
from loguru import logger
1516
from pydantic import ConfigDict, PrivateAttr, model_validator
1617
from torch.nn import Module
1718
from tqdm import tqdm
18-
19+
from torch.utils._pytree import tree_flatten
1920
from llmcompressor.core import Event, EventType, State
2021
from llmcompressor.modifiers import Modifier
2122
from llmcompressor.modifiers.awq.mappings import (
@@ -332,12 +333,20 @@ def _set_resolved_mappings(self, model: Module) -> None:
332333
module_to_name[module] = name
333334

334335
for mapping in self.mappings:
335-
target_patterns = (mapping.smooth_layer, *mapping.balance_layers)
336-
337-
for smooth_layer, *balance_layers in match_modules_set(
338-
model, target_patterns, self.ignore
336+
for smooth_layers, *nested_balance_layers in match_modules_set(
337+
model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
339338
):
339+
assert len(smooth_layers)==1, (
340+
"AWQ mappings need to match a single smoothlayer for each mapping but got "
341+
f"{[module_to_name.get(smooth_layer) for smooth_layer in smooth_layers]} "
342+
f"when matching {mapping.smooth_layer}"
343+
)
344+
smooth_layer = smooth_layers[0]
340345
smooth_name = module_to_name.get(smooth_layer)
346+
347+
#[[b00, b01, b02...], [b10, b11, b12,...], ...] v
348+
# [b00, b01, b02, ..., b10, b11, b12, ...]
349+
balance_layers = tree_flatten(nested_balance_layers)[0]
341350
balance_names = [
342351
module_to_name.get(balance_layer)
343352
for balance_layer in balance_layers
@@ -361,16 +370,17 @@ def _set_resolved_mappings(self, model: Module) -> None:
361370
continue
362371
else:
363372
# for multiple balance layers, find lowest common parent
364-
parent_name, parent = get_lowest_common_module(balance_names, model)
373+
ancestor_name = get_lowest_common_ancestor_name(balance_names)
374+
ancestor, ancestor_name = get_lowest_non_module_list_ancestor(ancestor_name, )
365375

366376
resolved_mappings.append(
367377
ResolvedMapping(
368378
smooth_name,
369379
smooth_layer,
370380
balance_layers,
371381
balance_names=balance_names,
372-
parent=parent,
373-
parent_name=parent_name,
382+
parent=ancestor,
383+
parent_name=ancestor_name,
374384
)
375385
)
376386
self._resolved_mappings = resolved_mappings
@@ -795,45 +805,25 @@ def _accumulate_mean(
795805
return (prev_sum + sum_added) / new_count, new_count
796806

797807

798-
def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]:
808+
def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]:
799809
"""
800-
Given a list of names, returns the lowest-scope common module.
810+
Given a name: foo.bar.baz, finds lowest ancestor that's not a ModuleList
811+
i.e. module_list.module_dict.module_list -> module_list.module_dict
812+
i.e. module_list.module_dict -> module_list.module_dict
813+
(self is an ancestor of self)
801814
802-
NOTE: function excludes modules of type ModuleList, which don't play
815+
NOTE: This is needed because ModuleLists don't play
803816
nicely with hooks because their forward method is never directly
804817
called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
805818
are selected based on router output and their forward method is called.
806819
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
807820
808821
Returns name of module and pointer to module
809-
810-
Implementation is a small alteration of os.path.commonprefix
811-
https://docs.python.org/3/library/os.path.html#os.path.commonprefix
812822
"""
813-
# adding "." before and after allows for handling a lot of corner
814-
# cases which were previously mishandled ([case]->prefix->result)
815-
# case 0: single module: [.abc.] -> .abc. -> abc
816-
# case 1: substring modules: [.abc., .ab.] -> .ab -> ""
817-
# case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab
818-
s1 = min(names) + "."
819-
s2 = max(names) + "."
820-
821-
# 1) find longest shared prefix
822-
parent_name = "."
823-
for i, c in enumerate(s1):
824-
if c != s2[i]:
825-
break
826-
parent_name += c
827-
828-
# 2) throw away module name fragment and leading dot
829-
# ".keep.thro" -> "keep"
830-
parent_name = parent_name[1 : parent_name.rfind(".")]
831-
832-
# 3) return first common module that is not a module list
833823
while True:
834-
if parent_name == "":
824+
if name == "":
835825
return "", module
836-
parent = get_layer_by_name(parent_name, module)
837-
if not isinstance(parent, torch.nn.ModuleList):
838-
return parent_name, parent
839-
parent_name = ".".join(parent_name.split(".")[:-1])
826+
module = get_layer_by_name(name, module)
827+
if not isinstance(module, torch.nn.ModuleList):
828+
return name, module
829+
name = ".".join(parent_name.split(".")[:-1])

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
TransformScheme,
1010
apply_transform_config,
1111
)
12+
from torch.utils._pytree import tree_flatten
1213
from compressed_tensors.utils import TorchDtype, get_head_dim
1314
from pydantic import Field, ValidationInfo, field_validator
1415
from transformers import PreTrainedModel
@@ -203,8 +204,10 @@ def _fuse_norms(self, model: PreTrainedModel):
203204
for mapping in self.norm_mappings:
204205
for norm, *linears in match_modules_set(
205206
model, (mapping.norm, *mapping.linears)
206-
):
207-
fuse_norm_linears(norm, linears)
207+
):
208+
# match_modules_set returns a list of lists
209+
assert len(norm) == 1
210+
fuse_norm_linears(norm[0], tree_flatten(linears)[0])
208211

209212
def _create_r1_scheme(self) -> TransformScheme:
210213
return TransformScheme(

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 28 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch.nn import Linear
66

77
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
8-
from llmcompressor.modifiers.awq.base import get_lowest_common_module
8+
from llmcompressor.modifiers.awq.base import get_lowest_non_module_list_ancestor
99
from llmcompressor.modifiers.factory import ModifierFactory
1010

1111

@@ -47,11 +47,18 @@ def test_set_resolved_mappings():
4747
"o_proj": Linear(4, 4),
4848
}
4949
)
50-
mlp = torch.nn.ModuleDict(
51-
{
52-
"up_proj": Linear(4, 10),
53-
"down_proj": Linear(10, 4),
54-
}
50+
mlp = torch.nn.ModuleList(
51+
"experts": torch.nn.ModuleList(
52+
[
53+
torch.nn.ModuleDict(
54+
{
55+
"gate_proj": Linear(4, 2),
56+
"down_proj": Linear(4, 2),
57+
}
58+
)
59+
for _ in range(3)
60+
]
61+
)
5562
)
5663
model = torch.nn.ModuleDict(
5764
{
@@ -83,8 +90,8 @@ def test_set_resolved_mappings():
8390
assert set(mapping.balance_names) == {"decoder.self_attn.o_proj"}
8491
assert mapping.parent_name == "decoder.self_attn.o_proj"
8592
if "mlp.up_proj" in mapping.smooth_name:
86-
assert set(mapping.balance_names) == {"decoder.mlp.down_proj"}
87-
assert mapping.parent_name == "decoder.mlp.down_proj"
93+
assert set(mapping.balance_names) == {"decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj"}
94+
assert mapping.parent_name == "decoder.mlp.down_proj" # TODODODO
8895

8996
awq = AWQModifier(
9097
mappings=[
@@ -193,72 +200,35 @@ def test_validate():
193200

194201

195202
@pytest.mark.unit
196-
def test_get_lowest_common_module():
197-
mlp = torch.nn.ModuleDict(
203+
def test_get_lowest_non_module_list_ancestor():
204+
model = torch.nn.ModuleDict(
198205
{
199206
"experts": torch.nn.ModuleList(
200207
[
201208
torch.nn.ModuleDict(
202209
{
203210
"gate_proj": Linear(4, 2),
204-
"down_proj": Linear(4, 2),
211+
"down_proj": Linear(2, 4),
205212
}
206213
)
207214
for _ in range(10)
208215
]
209216
)
210217
}
211218
)
212-
self_attn = torch.nn.ModuleDict(
213-
{
214-
"q_proj": Linear(4, 2),
215-
"k_proj": Linear(4, 2),
216-
"v_proj": Linear(4, 2),
217-
"o_proj": Linear(4, 4),
218-
}
219+
220+
ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
221+
"", model
219222
)
220-
model = torch.nn.ModuleDict(
221-
{
222-
"embed_tokens": Linear(4, 2),
223-
"decoder": torch.nn.ModuleDict(
224-
{
225-
"self_attn": self_attn,
226-
"mlp": mlp,
227-
}
228-
),
229-
}
230-
)
231-
232-
parent_name, parent = get_lowest_common_module(
233-
["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model
234-
)
235-
assert parent_name == "decoder.mlp" and parent == mlp
236-
237-
parent_name, parent = get_lowest_common_module(
238-
["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model
239-
)
240-
assert parent_name == "decoder.self_attn" and parent == self_attn
223+
assert ancestor_name == "" and ancestor == model
241224

242-
parent_name, parent = get_lowest_common_module(
243-
["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model
225+
ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
226+
["experts"], model
244227
)
245-
assert parent_name == "decoder" and parent == model["decoder"]
228+
assert ancestor_name == "" and ancestor == model
246229

247-
parent_name, parent = get_lowest_common_module(
248-
["embed_tokens", "decoder.self_attn.v_proj"], model
230+
ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
231+
"experts.1.gate_proj", model
249232
)
250-
assert parent_name == "" and parent == model
233+
assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"]
251234

252-
m = torch.nn.ModuleDict(
253-
{
254-
"abc": Linear(3, 3),
255-
"ab": torch.nn.ModuleDict({"a": Linear(3, 3)}),
256-
"z": Linear(3, 3),
257-
}
258-
)
259-
parent_name, parent = get_lowest_common_module(["abc", "ab"], m)
260-
assert parent_name == ""
261-
parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m)
262-
assert parent_name == "ab"
263-
parent_name, parent = get_lowest_common_module(["z"], m)
264-
assert parent_name == "z"

0 commit comments

Comments
 (0)