Skip to content

Commit 1f10f7d

Browse files
committed
format
Summary Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent f9e4cfa commit 1f10f7d

File tree

2 files changed

+4
-61
lines changed

2 files changed

+4
-61
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 4 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@
3232
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
3333
from llmcompressor.utils.helpers import calibration_forward_context
3434
from llmcompressor.utils.pytorch.module import (
35-
get_layer_by_name,
3635
get_module_to_name_dict,
3736
)
3837

@@ -365,7 +364,10 @@ def _set_resolved_mappings(self, model: Module) -> None:
365364

366365
ancestor_name = get_lowest_common_ancestor_name(balance_names)
367366
# no ModuleList ancestors
368-
while not isinstance((ancestor := model.get_submodule(ancestor_name)), torch.nn.ModuleList):
367+
while not isinstance(
368+
(ancestor := model.get_submodule(ancestor_name)),
369+
torch.nn.ModuleList,
370+
):
369371
ancestor_name = ancestor_name.rsplit(".", 1)[0]
370372

371373
resolved_mappings.append(
@@ -798,28 +800,3 @@ def _accumulate_mean(
798800
new_count = prev_count + num_added
799801

800802
return (prev_sum + sum_added) / new_count, new_count
801-
802-
803-
def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]:
804-
"""
805-
Given a name and a model, finds lowest ancestor of
806-
named module that's not a ModuleList
807-
i.e. module_list.module_dict.module_list -> module_list.module_dict
808-
i.e. module_list.module_dict -> module_list.module_dict
809-
(self is an ancestor of self)
810-
811-
NOTE: This is needed because ModuleLists don't play
812-
nicely with hooks because their forward method is never directly
813-
called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
814-
are selected based on router output and their forward method is called.
815-
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
816-
817-
Returns name of module and pointer to module
818-
"""
819-
while True:
820-
if name == "":
821-
return "", module
822-
current_module = get_layer_by_name(name, module)
823-
if not isinstance(current_module, torch.nn.ModuleList):
824-
return name, current_module
825-
name = ".".join(name.split(".")[:-1])

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 0 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
from torch.nn import Linear
66

77
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
8-
from llmcompressor.modifiers.awq.base import get_lowest_non_module_list_ancestor
98
from llmcompressor.modifiers.factory import ModifierFactory
109

1110

@@ -205,39 +204,6 @@ def test_validate():
205204
AWQModifier(scheme="W4A16", duo_scaling="x")
206205

207206

208-
@pytest.mark.unit
209-
def test_get_lowest_non_module_list_ancestor():
210-
model = torch.nn.ModuleDict(
211-
{
212-
"experts": torch.nn.ModuleList(
213-
[
214-
torch.nn.ModuleDict(
215-
{
216-
"gate_proj": Linear(4, 2),
217-
"down_proj": Linear(2, 4),
218-
}
219-
)
220-
for _ in range(10)
221-
]
222-
)
223-
}
224-
)
225-
226-
ancestor_name, ancestor = get_lowest_non_module_list_ancestor("", model)
227-
assert ancestor_name == "" and ancestor == model
228-
229-
ancestor_name, ancestor = get_lowest_non_module_list_ancestor("experts", model)
230-
assert ancestor_name == "" and ancestor == model
231-
232-
ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
233-
"experts.1.gate_proj", model
234-
)
235-
assert (
236-
ancestor_name == "experts.1.gate_proj"
237-
and ancestor == model["experts"][1]["gate_proj"]
238-
)
239-
240-
241207
@pytest.mark.unit
242208
def test_moe_multiple_balance_layers():
243209
"""Test AWQ mapping with multiple balance layers in MoE architecture"""

0 commit comments

Comments (0)