From 764f512d9dda79bec172087e8048be18bf4841f6 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Nov 2025 20:57:13 +0000 Subject: [PATCH 01/25] [AWQ] small refactor to use match_modules_set Summary: modified _set_resolved_mappings to get smoothing and balance layers at same time. Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 93 ++++++++++++++----------- 1 file changed, 52 insertions(+), 41 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index dc35a5c02..ccfbc9274 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -7,6 +7,7 @@ from compressed_tensors.utils import ( align_modules, get_execution_device, + match_modules_set, match_named_modules, update_offload_parameter, ) @@ -312,19 +313,22 @@ def _set_resolved_mappings(self, model: Module) -> None: into ResolvedMapping objects, resolving regular expressions. Result is stored in _resolved_mappings. - For each activation in the mapping list, we find the corresponding weight to - balance by searching for the longest substring. For instance, if our balance - weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we - would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and - repeat for model.layer.1 and so on + Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers) + that belong together in the model architecture. """ + # Build a module-to-name mapping for efficient lookups + module_to_name = {module: name for name, module in model.named_modules()} + resolved_mappings: list[ResolvedMapping] = [] for mapping_idx, mapping in enumerate(self.mappings): num_skipped_mappings = 0 - for smooth_name, smooth_layer in ( + # Use match_modules_set to find coherent sets of modules + target_patterns = (mapping.smooth_layer, *mapping.balance_layers) + + for modules_set in ( pbar := tqdm( - match_named_modules(model, [mapping.smooth_layer], self.ignore) + match_modules_set(model, target_patterns, self.ignore) ) ): pbar.set_description( @@ -332,48 +336,55 @@ def _set_resolved_mappings(self, model: Module) -> None: f" ({num_skipped_mappings} skipped)" ) - smooth_parent_name = ".".join(smooth_name.split(".")[:-1]) - smooth_parent = get_layer_by_name(smooth_parent_name, model) + # Unpack the matched set: first is smooth_layer, rest are balance_layers + smooth_layer = modules_set[0] + all_balance_layers = list(modules_set[1:]) - balance_layers, balance_names = [], [] - for balance_regex in mapping.balance_layers: - # find the submodules that match the activation layer - for balance_suffix, balance_layer in match_named_modules( - smooth_parent, [balance_regex], self.ignore - ): - balance_name = f"{smooth_parent_name}.{balance_suffix}" - - # exclude v_proj->o_proj mappings whose shapes are incompatible - # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 - if ( - isinstance(smooth_layer, torch.nn.Linear) - and isinstance(balance_layer, torch.nn.Linear) - and balance_name.endswith(".o_proj") - and ( - ( - smooth_name.endswith(".v_proj") - and smooth_layer.out_features - != balance_layer.in_features - ) - or ( - smooth_name.endswith(".qkv_proj") - and smooth_layer.out_features - != 3 * balance_layer.in_features - ) + # Get names using the pre-built mapping + smooth_name = module_to_name.get(smooth_layer) + if smooth_name is None: + continue + + # Filter balance layers, skipping incompatible ones + balance_layers = [] + balance_names = [] + + for 
balance_layer in all_balance_layers: + balance_name = module_to_name.get(balance_layer) + if balance_name is None: + continue + + # exclude v_proj->o_proj mappings whose shapes are incompatible + # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 + if ( + isinstance(smooth_layer, torch.nn.Linear) + and isinstance(balance_layer, torch.nn.Linear) + and balance_name.endswith(".o_proj") + and ( + ( + smooth_name.endswith(".v_proj") + and smooth_layer.out_features + != balance_layer.in_features + ) + or ( + smooth_name.endswith(".qkv_proj") + and smooth_layer.out_features + != 3 * balance_layer.in_features ) - ): - num_skipped_mappings += 1 - continue + ) + ): + num_skipped_mappings += 1 + continue - balance_layers.append(balance_layer) - balance_names.append(balance_name) + balance_layers.append(balance_layer) + balance_names.append(balance_name) if len(balance_layers) == 0: continue - elif len(balance_layers) == 1: + if len(balance_layers) == 1: # for single balance layer, parent is the balance layer - parent_name, parent = balance_name, balance_layer + parent_name, parent = balance_names[0], balance_layers[0] else: # for multiple balance layers, find lowest common parent parent_name, parent = get_lowest_common_parent(balance_names, model) From 53b8c5cfe6ac9413400efa8b09e8fd2bd266516f Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Nov 2025 21:07:36 +0000 Subject: [PATCH 02/25] formatting Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index ccfbc9274..7228805c2 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -327,9 +327,7 @@ def _set_resolved_mappings(self, model: Module) -> None: target_patterns = (mapping.smooth_layer, *mapping.balance_layers) for modules_set in ( - pbar := tqdm( - match_modules_set(model, target_patterns, self.ignore) - ) + pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) ): pbar.set_description( f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}" From 172185593ee01f49a4612494e71e87cd85d19edb Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 26 Nov 2025 18:09:44 +0000 Subject: [PATCH 03/25] fixing logic and test update Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 96 +++++++++---------- .../llmcompressor/modifiers/awq/test_base.py | 15 ++- 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 7228805c2..fa48fb8b4 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -313,20 +313,21 @@ def _set_resolved_mappings(self, model: Module) -> None: into ResolvedMapping objects, resolving regular expressions. Result is stored in _resolved_mappings. - Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers) - that belong together in the model architecture. + For each activation in the mapping list, we find the corresponding weight to + balance by searching for the longest substring. 
For instance, if our balance + weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we + would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and + repeat for model.layer.1 and so on """ - # Build a module-to-name mapping for efficient lookups - module_to_name = {module: name for name, module in model.named_modules()} - resolved_mappings: list[ResolvedMapping] = [] + module_to_name = {module: name for name, module in model.named_modules()} for mapping_idx, mapping in enumerate(self.mappings): num_skipped_mappings = 0 # Use match_modules_set to find coherent sets of modules target_patterns = (mapping.smooth_layer, *mapping.balance_layers) - for modules_set in ( + for smooth_layer, *balance_layers in ( pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) ): pbar.set_description( @@ -334,53 +335,21 @@ def _set_resolved_mappings(self, model: Module) -> None: f" ({num_skipped_mappings} skipped)" ) - # Unpack the matched set: first is smooth_layer, rest are balance_layers - smooth_layer = modules_set[0] - all_balance_layers = list(modules_set[1:]) - - # Get names using the pre-built mapping smooth_name = module_to_name.get(smooth_layer) - if smooth_name is None: - continue + balance_names = [ + module_to_name.get(balance_layer) + for balance_layer in balance_layers + ] - # Filter balance layers, skipping incompatible ones - balance_layers = [] - balance_names = [] - - for balance_layer in all_balance_layers: - balance_name = module_to_name.get(balance_layer) - if balance_name is None: - continue - - # exclude v_proj->o_proj mappings whose shapes are incompatible - # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 - if ( - isinstance(smooth_layer, torch.nn.Linear) - and isinstance(balance_layer, torch.nn.Linear) - and balance_name.endswith(".o_proj") - and ( - ( - smooth_name.endswith(".v_proj") - and smooth_layer.out_features - != balance_layer.in_features - ) - or ( - smooth_name.endswith(".qkv_proj") - and smooth_layer.out_features - != 3 * balance_layer.in_features - ) - ) - ): - num_skipped_mappings += 1 - continue - - balance_layers.append(balance_layer) - balance_names.append(balance_name) + all_compatible = _check_layers_are_compatible( + smooth_layer, smooth_name, balance_layers, balance_names + ) - if len(balance_layers) == 0: + # skip mapping if any of the balance layers are incompatible + if not all_compatible or len(balance_layers) == 0: + num_skipped_mappings += 1 continue - - if len(balance_layers) == 1: + elif len(balance_layers) == 1: # for single balance layer, parent is the balance layer parent_name, parent = balance_names[0], balance_layers[0] else: @@ -730,6 +699,35 @@ def _assert_all_activations_consumed(self): raise RuntimeError("Some cached activations were not used") +def _check_layers_are_compatible( + smooth_layer, smooth_name, balance_layers, balance_names +): + """ + returns True if they are all compatible + returns False if any smooth & balance layers are incompatible + """ + for balance_layer, balance_name in zip(balance_layers, balance_names): + # exclude v_proj->o_proj mappings whose shapes are incompatible + # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 + if ( + isinstance(smooth_layer, torch.nn.Linear) + and isinstance(balance_layer, torch.nn.Linear) + and balance_name.endswith(".o_proj") + and ( + ( + smooth_name.endswith(".v_proj") + and smooth_layer.out_features != balance_layer.in_features + ) + or ( + smooth_name.endswith(".qkv_proj") + and 
smooth_layer.out_features != 3 * balance_layer.in_features + ) + ) + ): + return False + return True + + def _pseudo_quantize_tensor( w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 ): diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 950ab0f51..83eaaa970 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -85,10 +85,12 @@ def test_set_resolved_mappings(): assert set(mapping.balance_names) == {"decoder.mlp.down_proj"} assert mapping.parent_name == "decoder.mlp.down_proj" - # make sure we exclude case where o_proj/v_proj shapes are mismatched awq = AWQModifier( mappings=[ + # make sure we exclude case where o_proj/v_proj shapes are mismatched AWQMapping("re:.*v_proj", ["re:.*o_proj"]), + # make sure we exclude mapping if any balance layers are skipped + AWQMapping("re:.*v_proj", ["re:.*z_proj", "re:.*o_proj"]), ], scheme="W4A16_ASYM", ) @@ -101,6 +103,7 @@ def test_set_resolved_mappings(): "q_proj": torch.nn.Linear(4, 2), "k_proj": torch.nn.Linear(4, 2), "v_proj": torch.nn.Linear(4, 2), + "z_proj": torch.nn.Linear(2, 4), "o_proj": torch.nn.Linear(4, 4), } ) @@ -109,6 +112,16 @@ def test_set_resolved_mappings(): } ) awq._set_resolved_mappings(model) + if len(awq._resolved_mappings) > 0: + assert all( + "o_proj" not in name for name in awq._resolved_mappings[0].balance_names + ), "should have skipped v->o mapping because o is incompatible" + assert all( + "z_proj" not in name for name in awq._resolved_mappings[0].balance_names + ), ( + "should have skipped v->[z,o] mapping because o is incompatible even though" + "z is compatible" + ) assert len(awq._resolved_mappings) == 0 From ec77bd09cc0693d938257e90ff239becfd7a8f34 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 27 Nov 2025 03:08:53 +0000 Subject: [PATCH 04/25] updates to get_lowest_common_x Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 67 ++++++++++++------- .../llmcompressor/modifiers/awq/test_base.py | 65 +++++++++++------- 2 files changed, 84 insertions(+), 48 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index fa48fb8b4..c4cb2ad5a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -320,21 +320,26 @@ def _set_resolved_mappings(self, model: Module) -> None: repeat for model.layer.1 and so on """ resolved_mappings: list[ResolvedMapping] = [] - module_to_name = {module: name for name, module in model.named_modules()} - for mapping_idx, mapping in enumerate(self.mappings): - num_skipped_mappings = 0 + + module_to_name = {} + for name, module in model.named_modules(): + if module in module_to_name: + logger.info( + f"Warning, {name} and {module_to_name[module]} both " + "share the same module the same module, " + "may have trouble resolving mappings." 
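+                    # note: Module.named_modules() deduplicates shared
+                    # submodules by default (remove_duplicate=True), so this
+                    # branch only fires if duplicates are surfaced explicitly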
+ ) + module_to_name[module] = name + + + + for mapping in self.mappings: - # Use match_modules_set to find coherent sets of modules target_patterns = (mapping.smooth_layer, *mapping.balance_layers) for smooth_layer, *balance_layers in ( - pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) + match_modules_set(model, target_patterns, self.ignore) ): - pbar.set_description( - f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}" - f" ({num_skipped_mappings} skipped)" - ) - smooth_name = module_to_name.get(smooth_layer) balance_names = [ module_to_name.get(balance_layer) @@ -347,14 +352,18 @@ def _set_resolved_mappings(self, model: Module) -> None: # skip mapping if any of the balance layers are incompatible if not all_compatible or len(balance_layers) == 0: - num_skipped_mappings += 1 + logger.info( + f"skipping AWQ for {smooth_name} for mapping {mapping}" + ( + " because found incompatible balance layers" + if not all_compatible else + f" because no balance layers were found" + ) + ) + continue - elif len(balance_layers) == 1: - # for single balance layer, parent is the balance layer - parent_name, parent = balance_names[0], balance_layers[0] else: # for multiple balance layers, find lowest common parent - parent_name, parent = get_lowest_common_parent(balance_names, model) + parent_name, parent = get_lowest_common_module(balance_names, model) resolved_mappings.append( ResolvedMapping( @@ -788,29 +797,41 @@ def _accumulate_mean( return (prev_sum + sum_added) / new_count, new_count -def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]: +def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]: """ - Given a list of names, returns the lowest-scope common parent. + Given a list of names, returns the lowest-scope common module. - NOTE: function excludes parents of type ModuleList, which don't play + NOTE: function excludes modules of type ModuleList, which don't play nicely with hooks because their forward method is never directly called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts are selected based on router output and their forward method is called. https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233 - Returns name of parent and pointer to parent module + Returns name of module and pointer to module Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix """ - s1 = min(names) - s2 = max(names) - parent_name = "" + # adding "." before and after allows for handling a lot of corner + # cases which were previously mishandled ([case]->prefix->result) + # case 0: single module: [.abc.] -> .abc. -> abc + # case 1: substring modules: [.abc., .ab.] -> .ab -> "" + # case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab + s1 = min(names) + "." + s2 = max(names) + "." + + # 1) find longest shared prefix + parent_name = "." 
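+    # worked example (names borrowed from the unit tests below, purely for
+    # illustration): names = ["decoder.mlp.experts.1.gate_proj",
+    # "decoder.mlp.experts.4.down_proj"] -> step 1 accumulates
+    # ".decoder.mlp.experts." (diverging at "1" vs "4"), step 2 trims it to
+    # "decoder.mlp.experts", and step 3 walks up past the ModuleList of
+    # experts to return ("decoder.mlp", <the mlp ModuleDict>)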
for i, c in enumerate(s1): if c != s2[i]: - parent_name = s1[:i].rstrip(".") break + parent_name += c + + # 2) throw away module name fragment and leading dot + # ".keep.thro" -> "keep" + parent_name = parent_name[1:parent_name.rfind(".")] + # 3) return first parent that is not a module list while True: if parent_name == "": return "", module diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 83eaaa970..2cb78fb65 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -2,9 +2,9 @@ import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme from pydantic import ValidationError - +from torch.nn import Linear from llmcompressor.modifiers.awq import AWQMapping, AWQModifier -from llmcompressor.modifiers.awq.base import get_lowest_common_parent +from llmcompressor.modifiers.awq.base import get_lowest_common_module from llmcompressor.modifiers.factory import ModifierFactory @@ -40,16 +40,16 @@ def test_set_resolved_mappings(): ) self_attn = torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 4), - "k_proj": torch.nn.Linear(4, 4), - "v_proj": torch.nn.Linear(4, 4), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 4), + "k_proj": Linear(4, 4), + "v_proj": Linear(4, 4), + "o_proj": Linear(4, 4), } ) mlp = torch.nn.ModuleDict( { - "up_proj": torch.nn.Linear(4, 10), - "down_proj": torch.nn.Linear(10, 4), + "up_proj": Linear(4, 10), + "down_proj": Linear(10, 4), } ) model = torch.nn.ModuleDict( @@ -100,11 +100,11 @@ def test_set_resolved_mappings(): { "self_attn": torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 2), - "k_proj": torch.nn.Linear(4, 2), - "v_proj": torch.nn.Linear(4, 2), - "z_proj": torch.nn.Linear(2, 4), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 2), + "k_proj": Linear(4, 2), + "v_proj": Linear(4, 2), + "z_proj": Linear(2, 4), + "o_proj": Linear(4, 4), } ) } @@ -192,15 +192,15 @@ def test_validate(): @pytest.mark.unit -def test_get_lowest_common_parent(): +def test_get_lowest_common_module(): mlp = torch.nn.ModuleDict( { "experts": torch.nn.ModuleList( [ torch.nn.ModuleDict( { - "gate_proj": torch.nn.Linear(4, 2), - "down_proj": torch.nn.Linear(4, 2), + "gate_proj": Linear(4, 2), + "down_proj": Linear(4, 2), } ) for _ in range(10) @@ -210,15 +210,15 @@ def test_get_lowest_common_parent(): ) self_attn = torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 2), - "k_proj": torch.nn.Linear(4, 2), - "v_proj": torch.nn.Linear(4, 2), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 2), + "k_proj": Linear(4, 2), + "v_proj": Linear(4, 2), + "o_proj": Linear(4, 4), } ) model = torch.nn.ModuleDict( { - "embed_tokens": torch.nn.Linear(4, 2), + "embed_tokens": Linear(4, 2), "decoder": torch.nn.ModuleDict( { "self_attn": self_attn, @@ -228,22 +228,37 @@ def test_get_lowest_common_parent(): } ) - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model ) assert parent_name == "decoder.mlp" and parent == mlp - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model ) assert parent_name == "decoder.self_attn" and parent == self_attn - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.mlp.experts.1.gate_proj", 
"decoder.self_attn.v_proj"], model ) assert parent_name == "decoder" and parent == model["decoder"] - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["embed_tokens", "decoder.self_attn.v_proj"], model ) assert parent_name == "" and parent == model + + m = torch.nn.ModuleDict( + { + "abc": Linear(3,3), + "ab": torch.nn.ModuleDict({"a": Linear(3,3)}), + "z": Linear(3,3) + } + ) + parent_name, parent = get_lowest_common_module(["abc", "ab"], m) + assert parent_name == "" + parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m) + assert parent_name == "ab" + parent_name, parent = get_lowest_common_module(["z"], m) + assert parent_name == "z" + From 06fbcd8009e825b62c31329e2f7ae66cd928a782 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 27 Nov 2025 03:26:33 +0000 Subject: [PATCH 05/25] format Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 24 +++++++++---------- .../llmcompressor/modifiers/awq/test_base.py | 8 +++---- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index c4cb2ad5a..4be556027 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -320,7 +320,7 @@ def _set_resolved_mappings(self, model: Module) -> None: repeat for model.layer.1 and so on """ resolved_mappings: list[ResolvedMapping] = [] - + module_to_name = {} for name, module in model.named_modules(): if module in module_to_name: @@ -331,14 +331,11 @@ def _set_resolved_mappings(self, model: Module) -> None: ) module_to_name[module] = name - - for mapping in self.mappings: - target_patterns = (mapping.smooth_layer, *mapping.balance_layers) - for smooth_layer, *balance_layers in ( - match_modules_set(model, target_patterns, self.ignore) + for smooth_layer, *balance_layers in match_modules_set( + model, target_patterns, self.ignore ): smooth_name = module_to_name.get(smooth_layer) balance_names = [ @@ -353,10 +350,11 @@ def _set_resolved_mappings(self, model: Module) -> None: # skip mapping if any of the balance layers are incompatible if not all_compatible or len(balance_layers) == 0: logger.info( - f"skipping AWQ for {smooth_name} for mapping {mapping}" + ( - " because found incompatible balance layers" - if not all_compatible else - f" because no balance layers were found" + f"skipping AWQ for {smooth_name} for mapping {mapping}" + + ( + " because found incompatible balance layers" + if not all_compatible + else " because no balance layers were found" ) ) @@ -812,7 +810,7 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix """ - # adding "." before and after allows for handling a lot of corner + # adding "." before and after allows for handling a lot of corner # cases which were previously mishandled ([case]->prefix->result) # case 0: single module: [.abc.] -> .abc. -> abc # case 1: substring modules: [.abc., .ab.] 
-> .ab -> "" @@ -829,9 +827,9 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod # 2) throw away module name fragment and leading dot # ".keep.thro" -> "keep" - parent_name = parent_name[1:parent_name.rfind(".")] + parent_name = parent_name[1 : parent_name.rfind(".")] - # 3) return first parent that is not a module list + # 3) return first common module that is not a module list while True: if parent_name == "": return "", module diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 2cb78fb65..e8103f9e3 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -3,6 +3,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme from pydantic import ValidationError from torch.nn import Linear + from llmcompressor.modifiers.awq import AWQMapping, AWQModifier from llmcompressor.modifiers.awq.base import get_lowest_common_module from llmcompressor.modifiers.factory import ModifierFactory @@ -250,9 +251,9 @@ def test_get_lowest_common_module(): m = torch.nn.ModuleDict( { - "abc": Linear(3,3), - "ab": torch.nn.ModuleDict({"a": Linear(3,3)}), - "z": Linear(3,3) + "abc": Linear(3, 3), + "ab": torch.nn.ModuleDict({"a": Linear(3, 3)}), + "z": Linear(3, 3), } ) parent_name, parent = get_lowest_common_module(["abc", "ab"], m) @@ -261,4 +262,3 @@ def test_get_lowest_common_module(): assert parent_name == "ab" parent_name, parent = get_lowest_common_module(["z"], m) assert parent_name == "z" - From a4ed675eb168121b37125710ce1c8df15f1601c7 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 16:04:41 +0000 Subject: [PATCH 06/25] fixes Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 68 +++++++-------- .../modifiers/transform/spinquant/base.py | 7 +- .../llmcompressor/modifiers/awq/test_base.py | 86 ++++++------------- 3 files changed, 62 insertions(+), 99 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 4be556027..c1003791d 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -10,12 +10,13 @@ match_modules_set, match_named_modules, update_offload_parameter, + get_lowest_common_ancestor_name, ) from loguru import logger from pydantic import ConfigDict, PrivateAttr, model_validator from torch.nn import Module from tqdm import tqdm - +from torch.utils._pytree import tree_flatten from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.awq.mappings import ( @@ -332,12 +333,20 @@ def _set_resolved_mappings(self, model: Module) -> None: module_to_name[module] = name for mapping in self.mappings: - target_patterns = (mapping.smooth_layer, *mapping.balance_layers) - - for smooth_layer, *balance_layers in match_modules_set( - model, target_patterns, self.ignore + for smooth_layers, *nested_balance_layers in match_modules_set( + model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore ): + assert len(smooth_layers)==1, ( + "AWQ mappings need to match a single smoothlayer for each mapping but got " + f"{[module_to_name.get(smooth_layer) for smooth_layer in smooth_layers]} " + f"when matching {mapping.smooth_layer}" + ) + smooth_layer = smooth_layers[0] smooth_name = module_to_name.get(smooth_layer) + + #[[b00, b01, b02...], [b10, b11, b12,...], ...] v + # [b00, b01, b02, ..., b10, b11, b12, ...] 
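+                # torch.utils._pytree.tree_flatten returns a (leaves,
+                # treespec) pair; [0] keeps just the flat leaf list, e.g.
+                # tree_flatten([[a, b], [c]])[0] == [a, b, c]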
+ balance_layers = tree_flatten(nested_balance_layers)[0] balance_names = [ module_to_name.get(balance_layer) for balance_layer in balance_layers @@ -361,7 +370,8 @@ def _set_resolved_mappings(self, model: Module) -> None: continue else: # for multiple balance layers, find lowest common parent - parent_name, parent = get_lowest_common_module(balance_names, model) + ancestor_name = get_lowest_common_ancestor_name(balance_names) + ancestor, ancestor_name = get_lowest_non_module_list_ancestor(ancestor_name, ) resolved_mappings.append( ResolvedMapping( @@ -369,8 +379,8 @@ def _set_resolved_mappings(self, model: Module) -> None: smooth_layer, balance_layers, balance_names=balance_names, - parent=parent, - parent_name=parent_name, + parent=ancestor, + parent_name=ancestor_name, ) ) self._resolved_mappings = resolved_mappings @@ -795,45 +805,25 @@ def _accumulate_mean( return (prev_sum + sum_added) / new_count, new_count -def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]: +def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]: """ - Given a list of names, returns the lowest-scope common module. + Given a name: foo.bar.baz, finds lowest ancestor that's not a ModuleList + i.e. module_list.module_dict.module_list -> module_list.module_dict + i.e. module_list.module_dict -> module_list.module_dict + (self is an ancestor of self) - NOTE: function excludes modules of type ModuleList, which don't play + NOTE: This is needed because ModuleLists don't play nicely with hooks because their forward method is never directly called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts are selected based on router output and their forward method is called. https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233 Returns name of module and pointer to module - - Implementation is a small alteration of os.path.commonprefix - https://docs.python.org/3/library/os.path.html#os.path.commonprefix """ - # adding "." before and after allows for handling a lot of corner - # cases which were previously mishandled ([case]->prefix->result) - # case 0: single module: [.abc.] -> .abc. -> abc - # case 1: substring modules: [.abc., .ab.] -> .ab -> "" - # case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab - s1 = min(names) + "." - s2 = max(names) + "." - - # 1) find longest shared prefix - parent_name = "." 
- for i, c in enumerate(s1): - if c != s2[i]: - break - parent_name += c - - # 2) throw away module name fragment and leading dot - # ".keep.thro" -> "keep" - parent_name = parent_name[1 : parent_name.rfind(".")] - - # 3) return first common module that is not a module list while True: - if parent_name == "": + if name == "": return "", module - parent = get_layer_by_name(parent_name, module) - if not isinstance(parent, torch.nn.ModuleList): - return parent_name, parent - parent_name = ".".join(parent_name.split(".")[:-1]) + module = get_layer_by_name(name, module) + if not isinstance(module, torch.nn.ModuleList): + return name, module + name = ".".join(parent_name.split(".")[:-1]) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 58f2e2977..50dddf51e 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -9,6 +9,7 @@ TransformScheme, apply_transform_config, ) +from torch.utils._pytree import tree_flatten from compressed_tensors.utils import TorchDtype, get_head_dim from pydantic import Field, ValidationInfo, field_validator from transformers import PreTrainedModel @@ -204,8 +205,10 @@ def _fuse_norms(self, model: PreTrainedModel): for mapping in self.norm_mappings: for norm, *linears in match_modules_set( model, (mapping.norm, *mapping.linears) - ): - fuse_norm_linears(norm, linears) + ): + # match_modules_set returns a list of lists + assert len(norm) == 1 + fuse_norm_linears(norm[0], tree_flatten(linears)[0]) def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index e8103f9e3..119dfc11a 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -5,7 +5,7 @@ from torch.nn import Linear from llmcompressor.modifiers.awq import AWQMapping, AWQModifier -from llmcompressor.modifiers.awq.base import get_lowest_common_module +from llmcompressor.modifiers.awq.base import get_lowest_non_module_list_ancestor from llmcompressor.modifiers.factory import ModifierFactory @@ -47,11 +47,18 @@ def test_set_resolved_mappings(): "o_proj": Linear(4, 4), } ) - mlp = torch.nn.ModuleDict( - { - "up_proj": Linear(4, 10), - "down_proj": Linear(10, 4), - } + mlp = torch.nn.ModuleList( + "experts": torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + "gate_proj": Linear(4, 2), + "down_proj": Linear(4, 2), + } + ) + for _ in range(3) + ] + ) ) model = torch.nn.ModuleDict( { @@ -83,8 +90,8 @@ def test_set_resolved_mappings(): assert set(mapping.balance_names) == {"decoder.self_attn.o_proj"} assert mapping.parent_name == "decoder.self_attn.o_proj" if "mlp.up_proj" in mapping.smooth_name: - assert set(mapping.balance_names) == {"decoder.mlp.down_proj"} - assert mapping.parent_name == "decoder.mlp.down_proj" + assert set(mapping.balance_names) == {"decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj"} + assert mapping.parent_name == "decoder.mlp.down_proj" # TODODODO awq = AWQModifier( mappings=[ @@ -193,15 +200,15 @@ def test_validate(): @pytest.mark.unit -def test_get_lowest_common_module(): - mlp = torch.nn.ModuleDict( +def test_get_lowest_non_module_list_ancestor(): + model = torch.nn.ModuleDict( { "experts": torch.nn.ModuleList( [ torch.nn.ModuleDict( { "gate_proj": Linear(4, 2), - "down_proj": Linear(4, 2), + "down_proj": Linear(2, 4), } ) for _ 
in range(10) @@ -209,56 +216,19 @@ def test_get_lowest_common_module(): ) } ) - self_attn = torch.nn.ModuleDict( - { - "q_proj": Linear(4, 2), - "k_proj": Linear(4, 2), - "v_proj": Linear(4, 2), - "o_proj": Linear(4, 4), - } + + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + "", model ) - model = torch.nn.ModuleDict( - { - "embed_tokens": Linear(4, 2), - "decoder": torch.nn.ModuleDict( - { - "self_attn": self_attn, - "mlp": mlp, - } - ), - } - ) - - parent_name, parent = get_lowest_common_module( - ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model - ) - assert parent_name == "decoder.mlp" and parent == mlp - - parent_name, parent = get_lowest_common_module( - ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model - ) - assert parent_name == "decoder.self_attn" and parent == self_attn + assert ancestor_name == "" and ancestor == model - parent_name, parent = get_lowest_common_module( - ["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + ["experts"], model ) - assert parent_name == "decoder" and parent == model["decoder"] + assert ancestor_name == "" and ancestor == model - parent_name, parent = get_lowest_common_module( - ["embed_tokens", "decoder.self_attn.v_proj"], model + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + "experts.1.gate_proj", model ) - assert parent_name == "" and parent == model + assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"] - m = torch.nn.ModuleDict( - { - "abc": Linear(3, 3), - "ab": torch.nn.ModuleDict({"a": Linear(3, 3)}), - "z": Linear(3, 3), - } - ) - parent_name, parent = get_lowest_common_module(["abc", "ab"], m) - assert parent_name == "" - parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m) - assert parent_name == "ab" - parent_name, parent = get_lowest_common_module(["z"], m) - assert parent_name == "z" From b582ab84265cb77bf162a39ac8604e44c2080078 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 18:44:36 +0000 Subject: [PATCH 07/25] tests Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 13 +-- .../llmcompressor/modifiers/awq/test_base.py | 99 ++++++++++++++++--- 2 files changed, 90 insertions(+), 22 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index c1003791d..fc7b84fc5 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -371,7 +371,7 @@ def _set_resolved_mappings(self, model: Module) -> None: else: # for multiple balance layers, find lowest common parent ancestor_name = get_lowest_common_ancestor_name(balance_names) - ancestor, ancestor_name = get_lowest_non_module_list_ancestor(ancestor_name, ) + ancestor_name, ancestor = get_lowest_non_module_list_ancestor(ancestor_name, model) resolved_mappings.append( ResolvedMapping( @@ -807,7 +807,8 @@ def _accumulate_mean( def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]: """ - Given a name: foo.bar.baz, finds lowest ancestor that's not a ModuleList + Given a name and a model, finds lowest ancestor of + named module that's not a ModuleList i.e. module_list.module_dict.module_list -> module_list.module_dict i.e. 
module_list.module_dict -> module_list.module_dict (self is an ancestor of self) @@ -823,7 +824,7 @@ def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Modu while True: if name == "": return "", module - module = get_layer_by_name(name, module) - if not isinstance(module, torch.nn.ModuleList): - return name, module - name = ".".join(parent_name.split(".")[:-1]) + current_module = get_layer_by_name(name, module) + if not isinstance(current_module, torch.nn.ModuleList): + return name, current_module + name = ".".join(name.split(".")[:-1]) diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 119dfc11a..b4161e7b3 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -47,18 +47,21 @@ def test_set_resolved_mappings(): "o_proj": Linear(4, 4), } ) - mlp = torch.nn.ModuleList( - "experts": torch.nn.ModuleList( - [ - torch.nn.ModuleDict( - { - "gate_proj": Linear(4, 2), - "down_proj": Linear(4, 2), - } - ) - for _ in range(3) - ] - ) + mlp = torch.nn.ModuleDict( + { + "experts": torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + "gate_proj": Linear(4, 2), + "up_proj": Linear(4, 2), + "down_proj": Linear(2, 4), + } + ) + for _ in range(3) + ] + ) + } ) model = torch.nn.ModuleDict( { @@ -89,9 +92,12 @@ def test_set_resolved_mappings(): if "self_attn.v_proj" in mapping.smooth_name: assert set(mapping.balance_names) == {"decoder.self_attn.o_proj"} assert mapping.parent_name == "decoder.self_attn.o_proj" - if "mlp.up_proj" in mapping.smooth_name: - assert set(mapping.balance_names) == {"decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj"} - assert mapping.parent_name == "decoder.mlp.down_proj" # TODODODO + if "mlp.experts" in mapping.smooth_name and "up_proj" in mapping.smooth_name: + expert_idx = mapping.smooth_name.split(".")[-2] + expected_down_proj = f"decoder.mlp.experts.{expert_idx}.down_proj" + assert set(mapping.balance_names) == {expected_down_proj} + assert mapping.parent_name == expected_down_proj + assert mapping.parent == mlp["experts"][int(expert_idx)]["down_proj"] awq = AWQModifier( mappings=[ @@ -223,7 +229,7 @@ def test_get_lowest_non_module_list_ancestor(): assert ancestor_name == "" and ancestor == model ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - ["experts"], model + "experts", model ) assert ancestor_name == "" and ancestor == model @@ -232,3 +238,64 @@ def test_get_lowest_non_module_list_ancestor(): ) assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"] + +@pytest.mark.unit +def test_moe_multiple_balance_layers(): + """Test AWQ mapping with multiple balance layers in MoE architecture""" + awq = AWQModifier( + mappings=[ + # Map input_layernorm to multiple experts' gate_proj and up_proj + AWQMapping( + "re:.*input_layernorm", + ["re:.*gate_proj", "re:.*up_proj"], + ), + ], + scheme="W4A16_ASYM", + ) + + # Create a simplified MoE model structure + mlp = torch.nn.ModuleDict( + { + "experts": torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + "gate_proj": Linear(4, 4), + "up_proj": Linear(4, 4), + "down_proj": Linear(4, 4), + } + ) + for _ in range(2) + ] + ) + } + ) + model = torch.nn.ModuleDict( + { + "layer": torch.nn.ModuleDict( + { + "input_layernorm": torch.nn.LayerNorm(4), + "mlp": mlp, + } + ) + } + ) + + awq._set_resolved_mappings(model) + + # Should have one mapping for input_layernorm + assert len(awq._resolved_mappings) == 1 + mapping 
= awq._resolved_mappings[0] + + # Should map to all gate_proj and up_proj across all experts + expected_balance_names = { + "layer.mlp.experts.0.gate_proj", + "layer.mlp.experts.0.up_proj", + "layer.mlp.experts.1.gate_proj", + "layer.mlp.experts.1.up_proj", + } + assert set(mapping.balance_names) == expected_balance_names + + assert mapping.parent_name == "layer.mlp" + assert mapping.parent == mlp + From cdb78d51b9ee4b963125a5487b5a9f1158972c2a Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 20:27:15 +0000 Subject: [PATCH 08/25] fixes and formatting Summary fix smoothquant logic to align with AWQ Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 38 +++++++--------- .../modifiers/smoothquant/base.py | 44 +++++++++---------- .../modifiers/smoothquant/utils.py | 4 +- .../modifiers/transform/spinquant/base.py | 4 +- src/llmcompressor/utils/pytorch/module.py | 14 ++++++ .../llmcompressor/modifiers/awq/test_base.py | 16 +++---- 6 files changed, 62 insertions(+), 58 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index fc7b84fc5..cf74dc23c 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -7,16 +7,17 @@ from compressed_tensors.utils import ( align_modules, get_execution_device, + get_lowest_common_ancestor_name, match_modules_set, match_named_modules, update_offload_parameter, - get_lowest_common_ancestor_name, ) from loguru import logger from pydantic import ConfigDict, PrivateAttr, model_validator from torch.nn import Module -from tqdm import tqdm from torch.utils._pytree import tree_flatten +from tqdm import tqdm + from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.awq.mappings import ( @@ -30,7 +31,10 @@ from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context -from llmcompressor.utils.pytorch.module import get_layer_by_name +from llmcompressor.utils.pytorch.module import ( + get_layer_by_name, + get_module_to_name_dict, +) __all__ = ["AWQModifier"] @@ -321,30 +325,20 @@ def _set_resolved_mappings(self, model: Module) -> None: repeat for model.layer.1 and so on """ resolved_mappings: list[ResolvedMapping] = [] - - module_to_name = {} - for name, module in model.named_modules(): - if module in module_to_name: - logger.info( - f"Warning, {name} and {module_to_name[module]} both " - "share the same module the same module, " - "may have trouble resolving mappings." - ) - module_to_name[module] = name - + module_to_name = get_module_to_name_dict(model) for mapping in self.mappings: for smooth_layers, *nested_balance_layers in match_modules_set( model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore ): - assert len(smooth_layers)==1, ( - "AWQ mappings need to match a single smoothlayer for each mapping but got " - f"{[module_to_name.get(smooth_layer) for smooth_layer in smooth_layers]} " - f"when matching {mapping.smooth_layer}" + assert len(smooth_layers) == 1, ( + "AWQ mappings need to match a single smoothlayer for each " + f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}" + f" for mapping: {mapping}" ) smooth_layer = smooth_layers[0] smooth_name = module_to_name.get(smooth_layer) - #[[b00, b01, b02...], [b10, b11, b12,...], ...] v + # [[b00, b01, b02...], [b10, b11, b12,...], ...] 
v # [b00, b01, b02, ..., b10, b11, b12, ...] balance_layers = tree_flatten(nested_balance_layers)[0] balance_names = [ @@ -371,7 +365,9 @@ def _set_resolved_mappings(self, model: Module) -> None: else: # for multiple balance layers, find lowest common parent ancestor_name = get_lowest_common_ancestor_name(balance_names) - ancestor_name, ancestor = get_lowest_non_module_list_ancestor(ancestor_name, model) + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + ancestor_name, model + ) resolved_mappings.append( ResolvedMapping( @@ -807,7 +803,7 @@ def _accumulate_mean( def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]: """ - Given a name and a model, finds lowest ancestor of + Given a name and a model, finds lowest ancestor of named module that's not a ModuleList i.e. module_list.module_dict.module_list -> module_list.module_dict i.e. module_list.module_dict -> module_list.module_dict diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index dcefa2fa4..ea5939c57 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -1,11 +1,12 @@ from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple import torch -from compressed_tensors.utils import align_module_device, match_named_modules +from compressed_tensors.utils import align_module_device, match_modules_set from loguru import logger from pydantic import ConfigDict, Field from torch.nn import Module +from torch.utils._pytree import tree_flatten from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier @@ -14,7 +15,7 @@ handle_mapping_resolution_errors, ) from llmcompressor.utils.fsdp.helpers import get_fsdp_parent -from llmcompressor.utils.pytorch.module import get_layer_by_name +from llmcompressor.utils.pytorch.module import get_module_to_name_dict MINIMUM_SMOOTHING_SCALE = 1e-5 @@ -95,7 +96,7 @@ class SmoothQuantModifier(Modifier): """ smoothing_strength: float = 0.5 - mappings: Optional[List[Union[Tuple, List]]] = None + mappings: Optional[List[Tuple[List[str], str]]] = None ignore: Optional[List[str]] = None num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None @@ -198,27 +199,22 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]: be balanced. 
""" resolved_mappings = [] - for to_balance, to_smooth in self.mappings: - to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth - - for smooth_name, smooth_layer in match_named_modules( - model, to_smooth_list, self.ignore + module_to_name = get_module_to_name_dict(model) + for mapping in self.mappings: + for *nested_balance_layers, smooth_layers in match_modules_set( + model, tree_flatten(mapping)[0], self.ignore ): - # Search for balance layers within the parent scope - smooth_parent_name = ".".join(smooth_name.split(".")[:-1]) - smooth_parent = get_layer_by_name(smooth_parent_name, model) - - balance_layers = [ - balance_layer - for _, balance_layer in match_named_modules( - smooth_parent, to_balance, self.ignore - ) - ] - - if balance_layers: - resolved_mappings.append( - SmoothQuantMapping(smooth_name, smooth_layer, balance_layers) - ) + assert len(smooth_layers) == 1, ( + "SmoothQuant mappings must match a single smooth layer for each " + f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}" + f" for mapping: {mapping}" + ) + smooth_layer = smooth_layers[0] + smooth_name = module_to_name.get(smooth_layers[0]) + balance_layers = tree_flatten(nested_balance_layers)[0] + resolved_mappings.append( + SmoothQuantMapping(smooth_name, smooth_layer, balance_layers) + ) return resolved_mappings diff --git a/src/llmcompressor/modifiers/smoothquant/utils.py b/src/llmcompressor/modifiers/smoothquant/utils.py index 8ab38f633..62f544dae 100644 --- a/src/llmcompressor/modifiers/smoothquant/utils.py +++ b/src/llmcompressor/modifiers/smoothquant/utils.py @@ -1,6 +1,6 @@ import functools from collections import namedtuple -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple from loguru import logger @@ -10,7 +10,7 @@ "DEFAULT_SMOOTHQUANT_MAPPINGS", ] -LayerMapType = Tuple[Union[List[str], str], Union[List[str], str]] +LayerMapType = Tuple[List[str], str] LayerMap: LayerMapType = namedtuple("LayerMap", ["balance_layers", "smooth_layers"]) DEFAULT_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [ diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 50dddf51e..0a4a2a54f 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -9,9 +9,9 @@ TransformScheme, apply_transform_config, ) -from torch.utils._pytree import tree_flatten from compressed_tensors.utils import TorchDtype, get_head_dim from pydantic import Field, ValidationInfo, field_validator +from torch.utils._pytree import tree_flatten from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State @@ -205,7 +205,7 @@ def _fuse_norms(self, model: PreTrainedModel): for mapping in self.norm_mappings: for norm, *linears in match_modules_set( model, (mapping.norm, *mapping.linears) - ): + ): # match_modules_set returns a list of lists assert len(norm) == 1 fuse_norm_linears(norm[0], tree_flatten(linears)[0]) diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index 6d2152fe4..fc4da3d04 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -10,6 +10,7 @@ import torch from compressed_tensors import InternalModule from compressed_tensors.quantization.utils import is_module_quantized +from loguru import logger from torch.nn import Linear, Module, Parameter from torch.nn.modules.conv import _ConvNd from transformers 
import PreTrainedModel @@ -369,3 +370,16 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module: if not layer_name: return module return attrgetter(layer_name)(module) + + +def get_module_to_name_dict(model: Module) -> dict[Module:str]: + module_to_name = {} + for name, module in model.named_modules(): + if module in module_to_name: + logger.info( + f"Warning, {name} and {module_to_name[module]} both " + "share the same module the same module, " + "may have trouble resolving mappings." + ) + module_to_name[module] = name + return module_to_name diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index b4161e7b3..5ed6594ee 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -222,21 +222,20 @@ def test_get_lowest_non_module_list_ancestor(): ) } ) - - ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - "", model - ) + + ancestor_name, ancestor = get_lowest_non_module_list_ancestor("", model) assert ancestor_name == "" and ancestor == model - ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - "experts", model - ) + ancestor_name, ancestor = get_lowest_non_module_list_ancestor("experts", model) assert ancestor_name == "" and ancestor == model ancestor_name, ancestor = get_lowest_non_module_list_ancestor( "experts.1.gate_proj", model ) - assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"] + assert ( + ancestor_name == "experts.1.gate_proj" + and ancestor == model["experts"][1]["gate_proj"] + ) @pytest.mark.unit @@ -298,4 +297,3 @@ def test_moe_multiple_balance_layers(): assert mapping.parent_name == "layer.mlp" assert mapping.parent == mlp - From 6fc873de6e259ee9984d9380af13b929a6ed1ee7 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 20:55:30 +0000 Subject: [PATCH 09/25] fixing tests to old versions Summary Signed-off-by: HDCharles --- .../logarithmic_equalization/test_pytorch.py | 11 ++++----- .../modifiers/smoothquant/test_pytorch.py | 24 +++++++++---------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py index f3e948469..3f111aa74 100644 --- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py @@ -8,18 +8,15 @@ @pytest.mark.unit def test_log_equalization_mapping(state): - # Use regex patterns with parent-scoped search - # Searches for balance layers within the parent of smooth layer - mappings = [(["re:^fc2$"], "re:.*block1\\.fc1$")] + mappings = [(["seq.fc2"], "seq.block1.fc1")] modifier = LogarithmicEqualizationModifier(mappings=mappings) modifier.ignore = [] modifier.resolved_mappings_ = modifier._resolve_mappings(state.model) - assert len(modifier.resolved_mappings_) == 1 + assert len(modifier.resolved_mappings_) == len(mappings) mapping = modifier.resolved_mappings_[0] - assert mapping.smooth_name == "seq.block1.fc1" + assert mapping.smooth_name == mappings[0][1] assert isinstance(mapping.smooth_layer, Linear) - assert len(mapping.balance_layers) == 1 - assert isinstance(mapping.balance_layers[0], Linear) + assert isinstance(mapping.balance_layers[0], Linear) \ No newline at end of file diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py 
b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index aefa6b957..0abe22523 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -4,23 +4,23 @@ from llmcompressor.modifiers.smoothquant import SmoothQuantModifier +import pytest +from torch.nn import Linear + +from llmcompressor.modifiers.smoothquant import SmoothQuantModifier + + @pytest.mark.unit def test_smooth_quant_mapping(state): - # Use regex patterns with parent-scoped search - # ^fc1$ matches only direct child "fc1", not nested "block1.fc1" - mappings = [(["re:^fc1$"], "re:.*fc2$")] + mappings = [(["seq.fc1"], "seq.fc2")] modifier = SmoothQuantModifier(mappings=mappings) modifier.ignore = [] modifier.resolved_mappings_ = modifier._resolve_mappings(state.model) - # Should match seq.fc2 and block1.fc2 (both end with fc2) - assert len(modifier.resolved_mappings_) == 2 + assert len(modifier.resolved_mappings_) == len(mappings) - # Verify seq.fc2 mapping - should find only seq.fc1 (direct child) - seq_mapping = [ - m for m in modifier.resolved_mappings_ if m.smooth_name == "seq.fc2" - ][0] - assert isinstance(seq_mapping.smooth_layer, Linear) - assert len(seq_mapping.balance_layers) == 1 - assert isinstance(seq_mapping.balance_layers[0], Linear) + mapping = modifier.resolved_mappings_[0] + assert mapping.smooth_name == mappings[0][1] + assert isinstance(mapping.smooth_layer, Linear) + assert isinstance(mapping.balance_layers[0], Linear) \ No newline at end of file From 5ae5ad0b4d9d0a037fb027ccc6a79cba5a309c7f Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 21:11:42 +0000 Subject: [PATCH 10/25] formatting Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 11 +++++------ src/llmcompressor/utils/pytorch/module.py | 3 +-- .../logarithmic_equalization/test_pytorch.py | 2 +- .../pytorch/modifiers/smoothquant/test_pytorch.py | 8 +------- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index cf74dc23c..47c629f11 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -362,12 +362,11 @@ def _set_resolved_mappings(self, model: Module) -> None: ) continue - else: - # for multiple balance layers, find lowest common parent - ancestor_name = get_lowest_common_ancestor_name(balance_names) - ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - ancestor_name, model - ) + + ancestor_name = get_lowest_common_ancestor_name(balance_names) + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + ancestor_name, model + ) resolved_mappings.append( ResolvedMapping( diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index fc4da3d04..19326c3a9 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -378,8 +378,7 @@ def get_module_to_name_dict(model: Module) -> dict[Module:str]: if module in module_to_name: logger.info( f"Warning, {name} and {module_to_name[module]} both " - "share the same module the same module, " - "may have trouble resolving mappings." 
+ "share the same module, which can result in unexpected behavior" ) module_to_name[module] = name return module_to_name diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py index 3f111aa74..49d0cdda3 100644 --- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py @@ -19,4 +19,4 @@ def test_log_equalization_mapping(state): mapping = modifier.resolved_mappings_[0] assert mapping.smooth_name == mappings[0][1] assert isinstance(mapping.smooth_layer, Linear) - assert isinstance(mapping.balance_layers[0], Linear) \ No newline at end of file + assert isinstance(mapping.balance_layers[0], Linear) diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index 0abe22523..ee8844041 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -4,12 +4,6 @@ from llmcompressor.modifiers.smoothquant import SmoothQuantModifier -import pytest -from torch.nn import Linear - -from llmcompressor.modifiers.smoothquant import SmoothQuantModifier - - @pytest.mark.unit def test_smooth_quant_mapping(state): mappings = [(["seq.fc1"], "seq.fc2")] @@ -23,4 +17,4 @@ def test_smooth_quant_mapping(state): mapping = modifier.resolved_mappings_[0] assert mapping.smooth_name == mappings[0][1] assert isinstance(mapping.smooth_layer, Linear) - assert isinstance(mapping.balance_layers[0], Linear) \ No newline at end of file + assert isinstance(mapping.balance_layers[0], Linear) From 0f1ffc2bccce51b615bb499840c99c9e75e8583a Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 21:54:05 +0000 Subject: [PATCH 11/25] comments Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 47c629f11..eda61f7ce 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -364,9 +364,9 @@ def _set_resolved_mappings(self, model: Module) -> None: continue ancestor_name = get_lowest_common_ancestor_name(balance_names) - ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - ancestor_name, model - ) + # no ModuleList ancestors + while not isinstance((ancestor := model.get_submodule(ancestor_name)), torch.nn.ModuleList): + ancestor_name = ancestor_name.rsplit(".", 1)[0] resolved_mappings.append( ResolvedMapping( From 23b1835a05205cac70a4fad303aee282bb184569 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 21:56:12 +0000 Subject: [PATCH 12/25] format Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 31 +++-------------- .../llmcompressor/modifiers/awq/test_base.py | 34 ------------------- 2 files changed, 4 insertions(+), 61 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index eda61f7ce..62a832ef3 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -32,7 +32,6 @@ from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( - 
get_layer_by_name, get_module_to_name_dict, ) @@ -365,7 +364,10 @@ def _set_resolved_mappings(self, model: Module) -> None: ancestor_name = get_lowest_common_ancestor_name(balance_names) # no ModuleList ancestors - while not isinstance((ancestor := model.get_submodule(ancestor_name)), torch.nn.ModuleList): + while not isinstance( + (ancestor := model.get_submodule(ancestor_name)), + torch.nn.ModuleList, + ): ancestor_name = ancestor_name.rsplit(".", 1)[0] resolved_mappings.append( @@ -798,28 +800,3 @@ def _accumulate_mean( new_count = prev_count + num_added return (prev_sum + sum_added) / new_count, new_count - - -def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]: - """ - Given a name and a model, finds lowest ancestor of - named module that's not a ModuleList - i.e. module_list.module_dict.module_list -> module_list.module_dict - i.e. module_list.module_dict -> module_list.module_dict - (self is an ancestor of self) - - NOTE: This is needed because ModuleLists don't play - nicely with hooks because their forward method is never directly - called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts - are selected based on router output and their forward method is called. - https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233 - - Returns name of module and pointer to module - """ - while True: - if name == "": - return "", module - current_module = get_layer_by_name(name, module) - if not isinstance(current_module, torch.nn.ModuleList): - return name, current_module - name = ".".join(name.split(".")[:-1]) diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 5ed6594ee..5d82da29a 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -5,7 +5,6 @@ from torch.nn import Linear from llmcompressor.modifiers.awq import AWQMapping, AWQModifier -from llmcompressor.modifiers.awq.base import get_lowest_non_module_list_ancestor from llmcompressor.modifiers.factory import ModifierFactory @@ -205,39 +204,6 @@ def test_validate(): AWQModifier(scheme="W4A16", duo_scaling="x") -@pytest.mark.unit -def test_get_lowest_non_module_list_ancestor(): - model = torch.nn.ModuleDict( - { - "experts": torch.nn.ModuleList( - [ - torch.nn.ModuleDict( - { - "gate_proj": Linear(4, 2), - "down_proj": Linear(2, 4), - } - ) - for _ in range(10) - ] - ) - } - ) - - ancestor_name, ancestor = get_lowest_non_module_list_ancestor("", model) - assert ancestor_name == "" and ancestor == model - - ancestor_name, ancestor = get_lowest_non_module_list_ancestor("experts", model) - assert ancestor_name == "" and ancestor == model - - ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - "experts.1.gate_proj", model - ) - assert ( - ancestor_name == "experts.1.gate_proj" - and ancestor == model["experts"][1]["gate_proj"] - ) - - @pytest.mark.unit def test_moe_multiple_balance_layers(): """Test AWQ mapping with multiple balance layers in MoE architecture""" From 5eb239fdb9d0964662a0f90e6024cab7871edb12 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Fri, 5 Dec 2025 22:01:41 +0000 Subject: [PATCH 13/25] tree flatten -> tree_leaves Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/smoothquant/base.py | 6 +++--- src/llmcompressor/modifiers/transform/spinquant/base.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py 
b/src/llmcompressor/modifiers/smoothquant/base.py index ea5939c57..5fbeb2d7c 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -6,7 +6,7 @@ from loguru import logger from pydantic import ConfigDict, Field from torch.nn import Module -from torch.utils._pytree import tree_flatten +from torch.utils._pytree import tree_leaves from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier @@ -202,7 +202,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]: module_to_name = get_module_to_name_dict(model) for mapping in self.mappings: for *nested_balance_layers, smooth_layers in match_modules_set( - model, tree_flatten(mapping)[0], self.ignore + model, tree_leaves(mapping), self.ignore ): assert len(smooth_layers) == 1, ( "SmoothQuant mappings must match a single smooth layer for each " @@ -211,7 +211,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]: ) smooth_layer = smooth_layers[0] smooth_name = module_to_name.get(smooth_layers[0]) - balance_layers = tree_flatten(nested_balance_layers)[0] + balance_layers = tree_leaves(nested_balance_layers) resolved_mappings.append( SmoothQuantMapping(smooth_name, smooth_layer, balance_layers) ) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 0a4a2a54f..d61fffec6 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -11,7 +11,7 @@ ) from compressed_tensors.utils import TorchDtype, get_head_dim from pydantic import Field, ValidationInfo, field_validator -from torch.utils._pytree import tree_flatten +from torch.utils._pytree import tree_leaves from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State @@ -208,7 +208,7 @@ def _fuse_norms(self, model: PreTrainedModel): ): # match_modules_set returns a list of lists assert len(norm) == 1 - fuse_norm_linears(norm[0], tree_flatten(linears)[0]) + fuse_norm_linears(norm[0], tree_leaves(linears)) def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( From a80f926c607b04611d691b63599a896b5f54ae22 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Fri, 5 Dec 2025 22:16:29 +0000 Subject: [PATCH 14/25] fix infinite loop and the asserts Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 46 ++++++++++++------- .../modifiers/smoothquant/base.py | 11 +++-- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 62a832ef3..23f77fce2 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -15,7 +15,7 @@ from loguru import logger from pydantic import ConfigDict, PrivateAttr, model_validator from torch.nn import Module -from torch.utils._pytree import tree_flatten +from torch.utils._pytree import tree_leaves from tqdm import tqdm from llmcompressor.core import Event, EventType, State @@ -32,7 +32,7 @@ from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( - get_module_to_name_dict, + get_module_to_name_dict, get_layer_by_name ) __all__ = ["AWQModifier"] @@ -329,17 +329,18 @@ def _set_resolved_mappings(self, model: Module) -> None: for smooth_layers, *nested_balance_layers in 
match_modules_set( model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore ): - assert len(smooth_layers) == 1, ( - "AWQ mappings need to match a single smoothlayer for each " - f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}" - f" for mapping: {mapping}" - ) + if len(smooth_layers)>1: + raise ValueError( + "AWQ needs to match a single smoothlayer for each mapping but " + f"got {[module_to_name.get(s) for s in smooth_layers]}" + f" for mapping: {mapping}" + ) smooth_layer = smooth_layers[0] smooth_name = module_to_name.get(smooth_layer) # [[b00, b01, b02...], [b10, b11, b12,...], ...] v # [b00, b01, b02, ..., b10, b11, b12, ...] - balance_layers = tree_flatten(nested_balance_layers)[0] + balance_layers = tree_leaves(nested_balance_layers) balance_names = [ module_to_name.get(balance_layer) for balance_layer in balance_layers @@ -351,7 +352,7 @@ def _set_resolved_mappings(self, model: Module) -> None: # skip mapping if any of the balance layers are incompatible if not all_compatible or len(balance_layers) == 0: - logger.info( + logger.warning( f"skipping AWQ for {smooth_name} for mapping {mapping}" + ( " because found incompatible balance layers" @@ -362,13 +363,9 @@ def _set_resolved_mappings(self, model: Module) -> None: continue - ancestor_name = get_lowest_common_ancestor_name(balance_names) - # no ModuleList ancestors - while not isinstance( - (ancestor := model.get_submodule(ancestor_name)), - torch.nn.ModuleList, - ): - ancestor_name = ancestor_name.rsplit(".", 1)[0] + ancestor_name, ancestor = get_lowest_ancestor_with_avoid( + balance_names, model, torch.nn.ModuleList + ) resolved_mappings.append( ResolvedMapping( @@ -741,6 +738,23 @@ def _check_layers_are_compatible( return False return True +def get_lowest_ancestor_with_avoid(name: str, model: Module, avoid=torch.nn.Module): + """ + get lowest ancestor that is not the avoided class/type + + NOTE: primarily used to exclude parents of type ModuleList, which don't play + nicely with hooks because their forward method is never directly + called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts + are selected based on router output and their forward method is called. 
+    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
+    """
+    while True:
+        if name == "":
+            return "", model
+        ancestor = get_layer_by_name(name, model)
+        if not isinstance(ancestor, avoid):
+            return name, ancestor
+        name = ".".join(name.split(".")[:-1])
 
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index 5fbeb2d7c..9840e520e 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -204,11 +204,12 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
             for *nested_balance_layers, smooth_layers in match_modules_set(
                 model, tree_leaves(mapping), self.ignore
             ):
-                assert len(smooth_layers) == 1, (
-                    "SmoothQuant mappings must match a single smooth layer for each "
-                    f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}"
-                    f" for mapping: {mapping}"
-                )
+                if len(smooth_layers)>1:
+                    raise ValueError (
+                        "SmoothQuant must match a single smooth layer for each mapping"
+                        f" but got {[module_to_name.get(s) for s in smooth_layers]}"
+                        f" for mapping: {mapping}"
+                    )
                 smooth_layer = smooth_layers[0]
                 smooth_name = module_to_name.get(smooth_layers[0])
                 balance_layers = tree_leaves(nested_balance_layers)

From a245c384ab269fc251bb3f569c4d0bd40a964fc4 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Fri, 5 Dec 2025 22:22:40 +0000
Subject: [PATCH 15/25] formatting

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/awq/base.py         | 8 +++++---
 src/llmcompressor/modifiers/smoothquant/base.py | 4 ++--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 23f77fce2..3d3254b9b 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -7,7 +7,6 @@
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
-    get_lowest_common_ancestor_name,
     match_modules_set,
     match_named_modules,
     update_offload_parameter,
 )
@@ -32,7 +31,8 @@
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
-    get_module_to_name_dict, get_layer_by_name
+    get_layer_by_name,
+    get_module_to_name_dict,
 )
 
 __all__ = ["AWQModifier"]
@@ -329,7 +329,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
             for smooth_layers, *nested_balance_layers in match_modules_set(
                 model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
             ):
-                if len(smooth_layers)>1:
+                if len(smooth_layers) > 1:
                     raise ValueError(
                         "AWQ needs to match a single smooth layer for each mapping but "
                         f"got {[module_to_name.get(s) for s in smooth_layers]}"
                         f" for mapping: {mapping}"
                     )
@@ -738,6 +738,7 @@ def _check_layers_are_compatible(
             return False
     return True
 
+
 def get_lowest_ancestor_with_avoid(name: str, model: Module, avoid=torch.nn.Module):
     """
     get lowest ancestor that is not the avoided class/type
diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index 9840e520e..98b6d0f09 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -204,8 +204,8 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
             for *nested_balance_layers, smooth_layers in match_modules_set(
                 model, tree_leaves(mapping), self.ignore
             ):
-                if len(smooth_layers)>1:
-                    raise ValueError (
+                if len(smooth_layers) > 1:
+                    raise ValueError(
                         "SmoothQuant must match a single smooth layer for each mapping"
                         f" but got {[module_to_name.get(s) for s in smooth_layers]}"
                         f" for mapping: {mapping}"

From 3af64fe5f467d6e42c2f744f356ce20729528f9b Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 8 Dec 2025 14:51:59 +0000
Subject: [PATCH 16/25] fix default for get_lowest_ancestor and correct common
 name

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/awq/base.py   | 6 ++++--
 src/llmcompressor/utils/pytorch/module.py | 1 +
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 3d3254b9b..819803191 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -7,6 +7,7 @@
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
+    get_lowest_common_ancestor_name,
     match_modules_set,
     match_named_modules,
     update_offload_parameter,
 )
@@ -363,8 +364,9 @@ def _set_resolved_mappings(self, model: Module) -> None:
                     continue
 
+                ancestor_name = get_lowest_common_ancestor_name(balance_names)
                 ancestor_name, ancestor = get_lowest_ancestor_with_avoid(
-                    balance_names, model, torch.nn.ModuleList
+                    ancestor_name, model, torch.nn.ModuleList
                 )
 
                 resolved_mappings.append(
@@ -739,7 +741,7 @@ def _check_layers_are_compatible(
     return True
 
 
-def get_lowest_ancestor_with_avoid(name: str, model: Module, avoid=torch.nn.Module):
+def get_lowest_ancestor_with_avoid(name: str, model: Module, avoid=torch.nn.ModuleList):
     """
     get lowest ancestor that is not the avoided class/type
 
diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py
index 19326c3a9..1be8d2676 100644
--- a/src/llmcompressor/utils/pytorch/module.py
+++ b/src/llmcompressor/utils/pytorch/module.py
@@ -369,6 +369,7 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module:
     """
     if not layer_name:
         return module
+    print("CCC", layer_name)
     return attrgetter(layer_name)(module)

From 4f5fd598db9161b12bb85252f479724c5078 Mon Sep 17 00:00:00 2001
From: HDCharles <39544797+HDCharles@users.noreply.github.com>
Date: Mon, 8 Dec 2025 10:52:58 -0500
Subject: [PATCH 17/25] Apply suggestion from @kylesayrs

Co-authored-by: Kyle Sayers
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
---
 src/llmcompressor/modifiers/awq/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 819803191..72526a95b 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -754,7 +754,7 @@ def get_lowest_ancestor_with_avoid(name: str, model: Module, avoid=torch.nn.Modu
     while True:
         if name == "":
             return "", model
-        ancestor = get_layer_by_name(name, model)
+        ancestor = model.get_submodule(name)
         if not isinstance(ancestor, avoid):
             return name, ancestor
         name = ".".join(name.split(".")[:-1])

From b271f29187d475d83533c8c165e1eacde31d2713 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 8 Dec 2025 13:26:25 -0500
Subject: [PATCH 18/25] fix serialization issue

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/smoothquant/base.py | 4 ++--
 src/llmcompressor/utils/pytorch/module.py       | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index 98b6d0f09..3eb4c66ba 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Sequence
 
 import torch
 from compressed_tensors.utils import align_module_device, match_modules_set
@@ -96,7 +96,7 @@ class SmoothQuantModifier(Modifier):
     """
 
     smoothing_strength: float = 0.5
-    mappings: Optional[List[Tuple[List[str], str]]] = None
+    mappings: Optional[List[Sequence[List[str]|str]]] = None
     ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None
     calibration_function: Optional[Callable] = None
diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py
index 1be8d2676..19326c3a9 100644
--- a/src/llmcompressor/utils/pytorch/module.py
+++ b/src/llmcompressor/utils/pytorch/module.py
@@ -369,7 +369,6 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module:
     """
     if not layer_name:
         return module
-    print("CCC", layer_name)
     return attrgetter(layer_name)(module)

From 4e96a7306e7dfd1de437ba23686e1012fcce5278 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 8 Dec 2025 13:27:57 -0500
Subject: [PATCH 19/25] formatting

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/awq/base.py         | 1 -
 src/llmcompressor/modifiers/smoothquant/base.py | 4 ++--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 72526a95b..88cd6a34c 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -32,7 +32,6 @@
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
-    get_layer_by_name,
     get_module_to_name_dict,
 )
 
diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index 3eb4c66ba..6b7535e73 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple, Sequence
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from compressed_tensors.utils import align_module_device, match_modules_set
@@ -96,7 +96,7 @@ class SmoothQuantModifier(Modifier):
     """
 
     smoothing_strength: float = 0.5
-    mappings: Optional[List[Sequence[List[str]|str]]] = None
+    mappings: Optional[List[Sequence[List[str] | str]]] = None
     ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None
     calibration_function: Optional[Callable] = None

From 5eb0de4b06c9352c02cfb9080a5d8f4221545db9 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 8 Dec 2025 16:35:30 -0500
Subject: [PATCH 20/25] wrestling with pydantic

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/smoothquant/base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index 6b7535e73..0118a9718 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from compressed_tensors.utils import align_module_device, match_modules_set
@@ -96,7 +96,7 @@ class SmoothQuantModifier(Modifier):
     """
 
     smoothing_strength: float = 0.5
-    mappings: Optional[List[Sequence[List[str] | str]]] = None
+    mappings: Optional[List[Union[Tuple, List]]] = None
     ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None
     calibration_function: Optional[Callable] = None

From 6d9f7731544d8af6eb809f776e5209e4a2e886da Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 8 Dec 2025 21:57:47 -0500
Subject: [PATCH 21/25] format

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/smoothquant/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index 0118a9718..53e925fa3 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.utils import align_module_device, match_modules_set

From 270662ffb6ffe58a5f29d35abcac24fd4dcffaa2 Mon Sep 17 00:00:00 2001
From: HDCharles <39544797+HDCharles@users.noreply.github.com>
Date: Tue, 9 Dec 2025 09:41:26 -0500
Subject: [PATCH 22/25] Apply suggestion from @kylesayrs

Co-authored-by: Kyle Sayers
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
---
 src/llmcompressor/utils/pytorch/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py
index 19326c3a9..4bc7fe46f 100644
--- a/src/llmcompressor/utils/pytorch/module.py
+++ b/src/llmcompressor/utils/pytorch/module.py
@@ -372,7 +372,7 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module:
     return attrgetter(layer_name)(module)
 
 
-def get_module_to_name_dict(model: Module) -> dict[Module:str]:
+def get_module_to_name_dict(model: Module) -> dict[Module, str]:
     module_to_name = {}
     for name, module in model.named_modules():
         if module in module_to_name:

From 129c6602b2d5556ecde872e8326df9dd17a3fe80 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 9 Dec 2025 11:15:31 -0500
Subject: [PATCH 23/25] comments

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/awq/base.py   | 22 ++++++++++++----------
 src/llmcompressor/utils/pytorch/module.py |  2 +-
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 88cd6a34c..52d4455a3 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -1,6 +1,6 @@
 import inspect
 from itertools import product
-from typing import Literal
+from typing import Literal, Iterator
 
 import torch
 from compressed_tensors.quantization import disable_quantization
@@ -363,9 +363,8 @@ def _set_resolved_mappings(self, model: Module) -> None:
                     continue
 
-                ancestor_name = get_lowest_common_ancestor_name(balance_names)
-                ancestor_name, ancestor = get_lowest_ancestor_with_avoid(
-                    ancestor_name, model, torch.nn.ModuleList
+                ancestor_name, ancestor = get_lowest_common_ancestor_with_avoid(
+                    balance_names, model, torch.nn.ModuleList
                 )
 
                 resolved_mappings.append(
@@ -739,9 +739,10 @@ def _check_layers_are_compatible(
     return True
 
 
-def get_lowest_ancestor_with_avoid(name: str, model: Module, avoid=torch.nn.ModuleList):
+def get_lowest_common_ancestor_with_avoid(balance_names: Iterator[str], model: Module, avoid=torch.nn.ModuleList):
     """
-    get lowest ancestor that is not the avoided class/type
+    Get the lowest ancestor that is not the avoided class/type.
+    see compressed_tensors.utils.get_lowest_common_ancestor_name for details on case handling.
 
     NOTE: primarily used to exclude parents of type ModuleList, which don't play
     nicely with hooks because their forward method is never directly
     called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
     are selected based on router output and their forward method is called.
     https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
     """
+    ancestor_name = get_lowest_common_ancestor_name(balance_names)
+
     while True:
-        if name == "":
+        if ancestor_name == "":
             return "", model
-        ancestor = model.get_submodule(name)
+        ancestor = model.get_submodule(ancestor_name)
         if not isinstance(ancestor, avoid):
-            return name, ancestor
-        name = ".".join(name.split(".")[:-1])
+            return ancestor_name, ancestor
+        ancestor_name = ".".join(ancestor_name.split(".")[:-1])
diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py
index 4bc7fe46f..700d8bc83 100644
--- a/src/llmcompressor/utils/pytorch/module.py
+++ b/src/llmcompressor/utils/pytorch/module.py
@@ -376,7 +376,7 @@ def get_module_to_name_dict(model: Module) -> dict[Module, str]:
     module_to_name = {}
     for name, module in model.named_modules():
         if module in module_to_name:
-            logger.info(
+            logger.warning(
                 f"Warning, {name} and {module_to_name[module]} both "
                 "share the same module, which can result in unexpected behavior"
             )

From 9a898d673d096f50ce12a31d172dcf9572d831e5 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 9 Dec 2025 16:06:35 -0500
Subject: [PATCH 24/25] format

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/awq/base.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 52d4455a3..d1fd97d63 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -1,6 +1,6 @@
 import inspect
 from itertools import product
-from typing import Literal, Iterator
+from typing import Iterator, Literal
 
 import torch
 from compressed_tensors.quantization import disable_quantization
@@ -739,10 +739,13 @@ def _check_layers_are_compatible(
     return True
 
 
-def get_lowest_common_ancestor_with_avoid(balance_names: Iterator[str], model: Module, avoid=torch.nn.ModuleList):
+def get_lowest_common_ancestor_with_avoid(
+    balance_names: Iterator[str], model: Module, avoid=torch.nn.ModuleList
+):
     """
     Get the lowest ancestor that is not the avoided class/type.
-    see compressed_tensors.utils.get_lowest_common_ancestor_name for details on case handling.
+    see compressed_tensors.utils.get_lowest_common_ancestor_name
+    for details on case handling.
 
     NOTE: primarily used to exclude parents of type ModuleList, which don't play
     nicely with hooks because their forward method is never directly
     called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
     are selected based on router output and their forward method is called.
     https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233

From b268cc5e330925399e366bf130e7187f982338e1 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 9 Dec 2025 16:10:33 -0500
Subject: [PATCH 25/25] rahul doesn't like my down arrow v

Summary

Signed-off-by: HDCharles
---
 src/llmcompressor/modifiers/awq/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index d1fd97d63..fbff5bab1 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -338,7 +338,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
                 smooth_layer = smooth_layers[0]
                 smooth_name = module_to_name.get(smooth_layer)
 
-                # [[b00, b01, b02...], [b10, b11, b12,...], ...] v
+                # [[b00, b01, b02...], [b10, b11, b12,...], ...] ↓
                 # [b00, b01, b02, ..., b10, b11, b12, ...]
                 balance_layers = tree_leaves(nested_balance_layers)
                 balance_names = [
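
A note on the tree_flatten -> tree_leaves change in PATCH 13/14: both flatten
the nested balance-layer lists that match_modules_set yields, but tree_leaves
returns the flat list directly instead of a (leaves, treespec) tuple, which is
all these call sites need. A minimal sketch of the behavior the modifiers rely
on (the string values below are toy stand-ins for real modules):

    from torch.utils._pytree import tree_flatten, tree_leaves

    # one inner list per balance pattern, as match_modules_set yields them
    nested_balance_layers = [["b00", "b01"], ["b10", "b11", "b12"]]

    leaves, spec = tree_flatten(nested_balance_layers)   # old call site
    assert tree_leaves(nested_balance_layers) == leaves  # new call site
    assert leaves == ["b00", "b01", "b10", "b11", "b12"]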
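And a standalone sketch of the ancestor walk that PATCHes 23/24 converge on in
get_lowest_common_ancestor_with_avoid: starting from the lowest common ancestor
of the balance layers, walk upward until the ancestor is not a ModuleList, so
that hooks land on a module whose forward is actually called. The toy module
tree and the inlined common-ancestor helper below are illustrative assumptions
only; the real helper lives in compressed_tensors.utils.

    from torch.nn import Linear, ModuleDict, ModuleList

    def lowest_common_ancestor_name(names):
        # longest shared dotted prefix (stand-in for
        # compressed_tensors.utils.get_lowest_common_ancestor_name)
        prefix = []
        for parts in zip(*(n.split(".") for n in names)):
            if len(set(parts)) != 1:
                break
            prefix.append(parts[0])
        return ".".join(prefix)

    def lowest_common_ancestor_with_avoid(names, model, avoid=ModuleList):
        ancestor_name = lowest_common_ancestor_name(names)
        while True:
            if ancestor_name == "":
                return "", model
            ancestor = model.get_submodule(ancestor_name)
            if not isinstance(ancestor, avoid):
                return ancestor_name, ancestor
            ancestor_name = ".".join(ancestor_name.split(".")[:-1])

    # toy MoE-like layout: experts live in a ModuleList
    model = ModuleDict(
        {
            "experts": ModuleList(
                [ModuleDict({"gate_proj": Linear(4, 2)}) for _ in range(2)]
            )
        }
    )

    name, ancestor = lowest_common_ancestor_with_avoid(
        ["experts.0.gate_proj", "experts.1.gate_proj"], model
    )
    # the common ancestor "experts" is a ModuleList, so the walk continues
    # upward until it reaches the root
    assert name == "" and ancestor is model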