
fix: layers_to_transform now correctly matches layer index on MoE models#3028

Open
Mr-Neutr0n wants to merge 1 commit into huggingface:main from Mr-Neutr0n:fix/layers-to-transform-moe-expert-index

Conversation

@Mr-Neutr0n

Summary

Fix layers_to_transform incorrectly matching MoE expert indices instead of layer indices.

Problem

When using layers_to_transform on MoE (Mixture of Experts) models, modules under paths like model.layers.<L>.mlp.experts.<E>.* were incorrectly filtered based on the expert index <E> instead of the layer index <L>.

Example:
For key model.layers.1.mlp.experts.0.up_proj:

  • Expected: Match layer index 1
  • Actual (before fix): Matched expert index 0

This caused layers_to_transform=[1] to incorrectly exclude model.layers.1.mlp.experts.0.up_proj (because 0 != 1) and incorrectly include model.layers.2.mlp.experts.1.up_proj (because 1 == 1).

Root Cause

The regex in check_target_module_exists() used a greedy .*, which matched as much as possible, leaving only the LAST digit in the path (the expert index) for the capture group:

layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key)  # Greedy

Solution

Changed to a non-greedy .*?, which matches as little as possible, capturing the FIRST digit in the path (the layer index):

layer_index = re.match(r".*?\.[^.]*\.(\d+)\.", key)  # Non-greedy
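
The difference is easy to verify in isolation with Python's re module, using the example key from above:

```python
import re

key = "model.layers.1.mlp.experts.0.up_proj"

# Greedy: `.*` consumes as much as possible before backtracking,
# so the capture group ends up on the LAST `<name>.<digits>.`
# segment in the path -- the expert index.
greedy = re.match(r".*\.[^.]*\.(\d+)\.", key)
print(greedy.group(1))  # -> "0" (expert index)

# Non-greedy: `.*?` consumes as little as possible, so the FIRST
# `<name>.<digits>.` segment matches -- the layer index.
lazy = re.match(r".*?\.[^.]*\.(\d+)\.", key)
print(lazy.group(1))  # -> "1" (layer index)
```

Only the quantifier changes; everything else about the pattern, including the capture group, stays the same.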

Test Cases Added

Added 7 test cases covering MoE scenarios:

  • With explicit layers_pattern
  • Without layers_pattern (default behavior)

Fixes #3016


When using layers_to_transform on MoE models, the regex was incorrectly
matching the LAST digit in the module path (expert index) instead of the
FIRST digit (layer index).

For example, with key 'model.layers.1.mlp.experts.0.up_proj':
- Old behavior: matched 0 (expert index)
- Fixed behavior: matches 1 (layer index)

The fix uses non-greedy .*? instead of greedy .* in the default regex
pattern, so it finds the first occurrence of the digit pattern.

Added test cases for MoE scenarios to prevent regression.

Fixes huggingface#3016
@Mr-Neutr0n
Author

Friendly bump! Let me know if there's anything I should update or improve to help move this forward.

@BenjaminBossan
Member

BenjaminBossan commented Feb 13, 2026

Thanks for this PR. Please check my comment here.

Edit: Add MichalMraz as co-author here.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

Thanks for the PR. I found that the code path for layers_pattern != None is not covered by this change. Let's update it too for consistency. More details in my comment.

Note: The failing CI is unrelated and can be ignored.

Comment on lines +133 to +136
("model.layers.1.mlp.experts.0.up_proj", ["up_proj"], [1], ["layers"], True),
("model.layers.1.mlp.experts.0.up_proj", ["up_proj"], [0], ["layers"], False), # expert 0, but layer 1
("model.layers.1.mlp.experts.1.up_proj", ["up_proj"], [1], ["layers"], True), # expert 1, layer 1
("model.layers.2.mlp.experts.1.up_proj", ["up_proj"], [1], ["layers"], False), # layer 2, not 1

These tests don't really test the mentioned bug. E.g. if the key is "model.layers.1.mlp.experts.0.up_proj", with the layers_pattern being ["layers"], there is just a single match, namely "layers.1". "experts.0" would never be matched in this situation, even with the current code from main. In fact, that code path is never reached, as we land here:

else:
    layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern
    for pattern in layers_pattern:
        layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key)
        if layer_index is not None:
            break

At this stage, we still match the last number, but since there is only one match, it doesn't matter. With the change from this PR, however, I would always expect the first number to match, to be consistent with the other code path. Therefore, I would suggest changing that regex too, and updating the examples to really test for this.
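
The inconsistency the reviewer describes can be sketched with a contrived key in which the pattern name appears twice (this key is hypothetical, purely for illustration, and not taken from a real model):

```python
import re

pattern = "layers"
# Hypothetical key repeating the pattern name twice.
key = "model.layers.2.decoder.layers.0.proj"

# Current (greedy) regex in the layers_pattern branch:
# the LAST "layers.<digits>" segment wins.
greedy = re.match(rf".*\.{pattern}\.(\d+)\.", key)
print(greedy.group(1))  # -> "0"

# Non-greedy variant, consistent with the fixed default path:
# the FIRST "layers.<digits>" segment wins.
lazy = re.match(rf".*?\.{pattern}\.(\d+)\.", key)
print(lazy.group(1))  # -> "2"
```

With real keys such as "model.layers.1.mlp.experts.0.up_proj" and layers_pattern=["layers"], the pattern matches only once, so both variants agree, which is why the existing tests could not catch the difference.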

@BenjaminBossan

ping @Mr-Neutr0n

@Mr-Neutr0n
Author

pong

@Mr-Neutr0n
Author

Will address the comments. Out for an emergency the past few days. Thanks for considering!



Development

Successfully merging this pull request may close these issues.

layers_to_transform incorrectly matches MoE expert indices (should match layer index)
