Skip to content

Fix MoE expert outputs not weighted by gate probabilities#201

Open
MRPRESIDENT66 wants to merge 2 commits into
tjake:mainfrom
MRPRESIDENT66:fix/moe-gate-weight-scaling
Open

Fix MoE expert outputs not weighted by gate probabilities#201
MRPRESIDENT66 wants to merge 2 commits into
tjake:mainfrom
MRPRESIDENT66:fix/moe-gate-weight-scaling

Conversation

@MRPRESIDENT66

Copy link
Copy Markdown

The MoEBlock computes softmax gate probabilities for each expert but only uses them for top-k selection. The expert outputs are summed without being scaled by their gate weights, which diverges from the standard MoE formulation (result = Σ gate_weight_i × expert_i(input)).

This matches the reference HuggingFace Transformers implementation in MixtralSparseMoeBlock where each expert's output is multiplied by its routing weight before accumulation.

The MoEBlock computes softmax gate probabilities for each expert but
only uses them for top-k selection. The expert outputs are summed
without being scaled by their gate weights, which diverges from the
standard MoE formulation (result = Σ gate_weight_i × expert_i(input)).

This matches the reference HuggingFace Transformers implementation in
MixtralSparseMoeBlock where each expert's output is multiplied by its
routing weight before accumulation.
@edwardcapriolo

Copy link
Copy Markdown
Contributor

There are no unit tests here. I will ty with mixtral and see how it goes

@edwardcapriolo

edwardcapriolo commented Apr 18, 2026

Copy link
Copy Markdown
Contributor

Changes output to useless:

Was:
Edward Capriolo is a writer, director, and producer of theatrical films, television, and documentaries. He is also a professor of film and media studies at the University of North Carolina, Wilmington. He is the founder and director of the Wilmington Film Festival, a juried competition for independent filmmakers. He is also the founder and director of the Wilmington Film Commission, a non-profit organization that supports the local film industry. He>

[main] INFO io.teknek.deliverance.model.AbstractModel - Tensor provider = Native SIMD Operations, parallelSplitSize = 32 
[main] INFO io.teknek.deliverance.model.AbstractModel - Model type = Q4, Working memory type = F32, Quantized memory type = I8
 
1
0
0
0
0
0
0
0
0

 mvn test -Dtest=MixralIT

@edwardcapriolo

Copy link
Copy Markdown
Contributor

Moving the scale BEFORE the results doesnt blow up the output.

 model.configurableTensorProvider.get().scale(gateWeight, moeResult, 0, model.config.embeddingLength);        
                                                                                                                                 
                    // matmul the projection and sum into result                                                                 
                    try (AbstractTensor bufq = model.maybeQuantize(buf)) { 
                    ```
                    
                   But in my limited testing the result is unchanged

Without re-normalization, the selected top-k gate weights (from a softmax
over all N experts) sum to less than 1, causing expert outputs to be scaled
down proportionally. This shrinks the MoE output magnitude and breaks the
residual stream statistics the model was trained with.

Re-normalize the selected weights to sum to 1 before applying them, matching
the reference HuggingFace Mixtral implementation.
@MRPRESIDENT66

Copy link
Copy Markdown
Author

Moving the scale BEFORE the results doesnt blow up the output.

 model.configurableTensorProvider.get().scale(gateWeight, moeResult, 0, model.config.embeddingLength);        
                                                                                                                                 
                    // matmul the projection and sum into result                                                                 
                    try (AbstractTensor bufq = model.maybeQuantize(buf)) { 
                    ```
                    
                   But in my limited testing the result is unchanged

@MRPRESIDENT66

Copy link
Copy Markdown
Author

Hi, sorry for not being able to test locally with a full Mixtral model.

Regarding your suggestion of moving scale before the matmul — I believe it has no effect because dotProductChunk uses result.set(...) which overwrites moeResult entirely, so any scaling applied before it gets lost.

I also realized my first version of this fix was incomplete. Multiplying by the raw softmax probabilities without re-normalizing causes the output to shrink, because the top-k weights (e.g. top-2 out of 8 experts) sum to less than 1. This is why the output broke in your test.

The updated fix adds a re-normalization step before applying the weights, ensuring the selected top-k weights sum to 1

@edwardcapriolo

Copy link
Copy Markdown
Contributor

Send the PR here. edwardcapriolo/deliverance#87

I dont have merge ability for this project and it is not being maintained ATM

MRPRESIDENT66 added a commit to MRPRESIDENT66/deliverance that referenced this pull request Apr 18, 2026
The MixtureOfExpertsBlock computes softmax gate probabilities but only
uses them for top-k selection. Expert outputs are summed without being
scaled by their gate weights.

This fix re-normalizes the selected top-k weights to sum to 1, then
scales each expert output by its normalized gate weight before
accumulation, matching the HuggingFace Mixtral reference implementation.

Fixes: edwardcapriolo#87
See also: tjake/Jlama#201
edwardcapriolo added a commit to edwardcapriolo/deliverance that referenced this pull request Apr 19, 2026
* Fix MoE expert outputs not weighted by gate probabilities

The MixtureOfExpertsBlock computes softmax gate probabilities but only
uses them for top-k selection. Expert outputs are summed without being
scaled by their gate weights.

This fix re-normalizes the selected top-k weights to sum to 1, then
scales each expert output by its normalized gate weight before
accumulation, matching the HuggingFace Mixtral reference implementation.

Fixes: #87
See also: tjake/Jlama#201

* Fix along with pr from mr-pres

---------

Co-authored-by: MRPRESIDENT66 <jinmingyijack@163.com>
edwardcapriolo added a commit to edwardcapriolo/deliverance that referenced this pull request Apr 19, 2026
* Fix MoE expert outputs not weighted by gate probabilities

The MixtureOfExpertsBlock computes softmax gate probabilities but only
uses them for top-k selection. Expert outputs are summed without being
scaled by their gate weights.

This fix re-normalizes the selected top-k weights to sum to 1, then
scales each expert output by its normalized gate weight before
accumulation, matching the HuggingFace Mixtral reference implementation.

Fixes: #87
See also: tjake/Jlama#201

* Fix along with pr from mr-pres

---------

Co-authored-by: MRPRESIDENT66 <jinmingyijack@163.com>
edwardcapriolo added a commit to edwardcapriolo/deliverance that referenced this pull request Apr 19, 2026
* Fix MoE expert outputs not weighted by gate probabilities

The MixtureOfExpertsBlock computes softmax gate probabilities but only
uses them for top-k selection. Expert outputs are summed without being
scaled by their gate weights.

This fix re-normalizes the selected top-k weights to sum to 1, then
scales each expert output by its normalized gate weight before
accumulation, matching the HuggingFace Mixtral reference implementation.

Fixes: #87
See also: tjake/Jlama#201

* Fix along with pr from mr-pres

* changes for removal

---------

Co-authored-by: MRPRESIDENT66 <jinmingyijack@163.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants