Fix MoE expert outputs not weighted by gate probabilities #201
MRPRESIDENT66 wants to merge 2 commits into
The MoEBlock computes softmax gate probabilities for each expert but only uses them for top-k selection. The expert outputs are summed without being scaled by their gate weights, which diverges from the standard MoE formulation (result = Σ gate_weight_i × expert_i(input)). This PR scales each expert's output by its gate weight, matching the reference HuggingFace Transformers implementation in MixtralSparseMoeBlock, where each expert's output is multiplied by its routing weight before accumulation.
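A minimal sketch of the difference between the two behaviors described above, written in plain Python rather than Jlama's Java. All names here (`softmax`, `top_k_indices`, `moe_unweighted`, `moe_weighted`) and the toy experts are hypothetical illustrations, not the MoEBlock code:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_indices(probs, k):
    """Indices of the k largest probabilities, highest first."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

def moe_unweighted(x, gate_logits, experts, k):
    """Reported behavior: gates only select experts; outputs are summed as-is."""
    probs = softmax(gate_logits)
    out = [0.0] * len(x)
    for i in top_k_indices(probs, k):
        y = experts[i](x)
        out = [o + v for o, v in zip(out, y)]
    return out

def moe_weighted(x, gate_logits, experts, k):
    """Standard formulation: result = sum_i gate_weight_i * expert_i(x)."""
    probs = softmax(gate_logits)
    out = [0.0] * len(x)
    for i in top_k_indices(probs, k):
        y = experts[i](x)
        out = [o + probs[i] * v for o, v in zip(out, y)]
    return out

# Toy experts (assumption for illustration): expert i multiplies its input by (i + 1).
experts = [lambda x, s=i + 1: [s * v for v in x] for i in range(4)]
```

With the same input and gate logits, `moe_weighted` returns a smaller-magnitude result than `moe_unweighted`, because each selected expert's output is multiplied by a probability below 1.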
There are no unit tests here. I will try with Mixtral and see how it goes.
This makes the output useless:
Moving the scale BEFORE the results doesn't blow up the output.
Without re-normalization, the selected top-k gate weights (from a softmax over all N experts) sum to less than 1, causing expert outputs to be scaled down proportionally. This shrinks the MoE output magnitude and breaks the residual stream statistics the model was trained with. Re-normalize the selected weights to sum to 1 before applying them, matching the reference HuggingFace Mixtral implementation.
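The re-normalization step can be sketched as follows, in plain Python with made-up gate logits for 8 experts. The function name `route` and its `renormalize` flag are hypothetical and only illustrate the arithmetic; they are not taken from Jlama or HuggingFace Transformers:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route(gate_logits, k, renormalize=True):
    """Return (indices, weights) for the top-k experts.

    The softmax is taken over all experts; with renormalize=True the
    selected weights are divided by their sum so they add up to 1.
    """
    probs = softmax(gate_logits)
    idx = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    weights = [probs[i] for i in idx]
    if renormalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    return idx, weights

logits = [2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # made-up values, 8 experts
_, raw = route(logits, k=2, renormalize=False)
_, norm = route(logits, k=2)
print(sum(raw))   # below 1: the mass assigned to unselected experts is dropped
print(sum(norm))  # 1 up to floating-point rounding
```

The ratio between the two selected weights is unchanged by the division; only their overall scale moves, which is what keeps the MoE output magnitude from shrinking.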
Hi, sorry for not being able to test locally with a full Mixtral model. Regarding your suggestion of moving the scale before the matmul: I believe it has no effect, because dotProductChunk uses result.set(...), which overwrites moeResult entirely, so any scaling applied before it is lost. I also realized that my first version of this fix was incomplete. Multiplying by the raw softmax probabilities without re-normalizing shrinks the output, because the top-k weights (e.g. top-2 out of 8 experts) sum to less than 1. This is why the output broke in your test. The updated fix adds a re-normalization step before applying the weights, ensuring the selected top-k weights sum to 1.
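The ordering argument can be illustrated with a toy sketch. It assumes, as the comment does, that the kernel write replaces the buffer contents rather than accumulating into them; whether dotProductChunk actually behaves this way is not verified here. `set_chunk` and `scale` are hypothetical stand-ins, not Jlama functions:

```python
def set_chunk(result, values):
    """Hypothetical stand-in for a write done with set(...): overwrites result."""
    for j, v in enumerate(values):
        result[j] = v

def scale(buf, factor):
    """Multiply a buffer in place by a scalar."""
    for j in range(len(buf)):
        buf[j] *= factor

expert_out = [2.0, 4.0]  # made-up expert output
weight = 0.25            # made-up gate weight

# Scale first, then write: the write replaces the scaled contents,
# so the weight never reaches the expert output.
before = [1.0, 1.0]
scale(before, weight)
set_chunk(before, expert_out)

# Write first, then scale: the weight is applied to the expert output.
after = [1.0, 1.0]
set_chunk(after, expert_out)
scale(after, weight)
```

Under that assumption, `before` ends up holding the unscaled expert output, while `after` holds the weighted one.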
Send the PR here: edwardcapriolo/deliverance#87. I don't have merge ability for this project, and it is not being maintained at the moment.
The MixtureOfExpertsBlock computes softmax gate probabilities but only uses them for top-k selection. Expert outputs are summed without being scaled by their gate weights. This fix re-normalizes the selected top-k weights to sum to 1, then scales each expert output by its normalized gate weight before accumulation, matching the HuggingFace Mixtral reference implementation. Fixes: edwardcapriolo#87 See also: tjake/Jlama#201
* Fix MoE expert outputs not weighted by gate probabilities

The MixtureOfExpertsBlock computes softmax gate probabilities but only uses them for top-k selection. Expert outputs are summed without being scaled by their gate weights. This fix re-normalizes the selected top-k weights to sum to 1, then scales each expert output by its normalized gate weight before accumulation, matching the HuggingFace Mixtral reference implementation.

Fixes: #87
See also: tjake/Jlama#201

* Fix along with pr from mr-pres

---------

Co-authored-by: MRPRESIDENT66 <jinmingyijack@163.com>
* Fix MoE expert outputs not weighted by gate probabilities

The MixtureOfExpertsBlock computes softmax gate probabilities but only uses them for top-k selection. Expert outputs are summed without being scaled by their gate weights. This fix re-normalizes the selected top-k weights to sum to 1, then scales each expert output by its normalized gate weight before accumulation, matching the HuggingFace Mixtral reference implementation.

Fixes: #87
See also: tjake/Jlama#201

* Fix along with pr from mr-pres

* changes for removal

---------

Co-authored-by: MRPRESIDENT66 <jinmingyijack@163.com>