|
32 | 32 | from llmcompressor.utils.fsdp.helpers import get_fsdp_parent |
33 | 33 | from llmcompressor.utils.helpers import calibration_forward_context |
34 | 34 | from llmcompressor.utils.pytorch.module import ( |
35 | | - get_layer_by_name, |
36 | 35 | get_module_to_name_dict, |
37 | 36 | ) |
38 | 37 |
|
@@ -365,7 +364,10 @@ def _set_resolved_mappings(self, model: Module) -> None: |
365 | 364 |
|
366 | 365 | ancestor_name = get_lowest_common_ancestor_name(balance_names) |
367 | 366 | # no ModuleList ancestors |
368 | | - while not isinstance((ancestor := model.get_submodule(ancestor_name)), torch.nn.ModuleList): |
| 367 | + while not isinstance( |
| 368 | + (ancestor := model.get_submodule(ancestor_name)), |
| 369 | + torch.nn.ModuleList, |
| 370 | + ): |
369 | 371 | ancestor_name = ancestor_name.rsplit(".", 1)[0] |
370 | 372 |
|
371 | 373 | resolved_mappings.append( |
@@ -798,28 +800,3 @@ def _accumulate_mean( |
798 | 800 | new_count = prev_count + num_added |
799 | 801 |
|
800 | 802 | return (prev_sum + sum_added) / new_count, new_count |
801 | | - |
802 | | - |
803 | | -def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]: |
804 | | - """ |
805 | | - Given a name and a model, finds lowest ancestor of |
806 | | - named module that's not a ModuleList |
807 | | - i.e. module_list.module_dict.module_list -> module_list.module_dict |
808 | | - i.e. module_list.module_dict -> module_list.module_dict |
809 | | - (self is an ancestor of self) |
810 | | -
|
811 | | - NOTE: This is needed because ModuleLists don't play |
812 | | - nicely with hooks because their forward method is never directly |
813 | | - called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts |
814 | | - are selected based on router output and their forward method is called. |
815 | | - https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233 |
816 | | -
|
817 | | - Returns name of module and pointer to module |
818 | | - """ |
819 | | - while True: |
820 | | - if name == "": |
821 | | - return "", module |
822 | | - current_module = get_layer_by_name(name, module) |
823 | | - if not isinstance(current_module, torch.nn.ModuleList): |
824 | | - return name, current_module |
825 | | - name = ".".join(name.split(".")[:-1]) |
0 commit comments