What
KF has no router z-loss. The existing z_loss (training/loss.py) is the PaLM LM-head logit regularizer and is unrelated to routing. MoE training is prone to router-logit growth, which leads to instability and expert collapse: top-1 routing at 410M collapses to a single expert, and raising moe_aux_loss_weight doesn't fix it and degrades LM loss. The standard fix is ST-MoE's router z-loss (Zoph et al. 2022), L_z = mean_token (logsumexp(router_logits))², used in Switch / Mixtral / Megatron-LM / DeepSpeed-MoE.
Add ModelConfig.moe_router_z_loss_weight: float = 0.0 (0 = off). Compute the z-loss in the routers (only when the weight > 0), expose it via MoEMLP.z_loss and Transformer.get_moe_router_z_loss() (mirroring aux_loss / get_moe_aux_loss), add weight × z_loss to the loss in scripts/train.py, and log moe/router_z_loss.
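A minimal sketch of the computation and the loss wiring described above, assuming PyTorch and router logits of shape (num_tokens, num_experts). The function name router_z_loss, the tensor names, and the weight value are illustrative and are not taken from model/router.py or scripts/train.py.

```python
import torch


def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """ST-MoE router z-loss: mean over tokens of logsumexp(logits)^2.

    router_logits: (num_tokens, num_experts), pre-softmax router outputs.
    """
    # Upcast to float32 so the logsumexp stays stable under bf16/fp16 training.
    lse = torch.logsumexp(router_logits.float(), dim=-1)  # (num_tokens,)
    return lse.pow(2).mean()


# Hypothetical wiring in the training step: the loss is only touched when enabled.
moe_router_z_loss_weight = 1e-3  # illustrative value; the config default is 0.0
router_logits = torch.zeros(4, 8, requires_grad=True)  # 4 tokens, 8 experts
lm_loss = torch.tensor(2.0)  # stand-in for the language-model loss

loss = lm_loss
if moe_router_z_loss_weight > 0:
    z_loss = router_z_loss(router_logits)
    loss = loss + moe_router_z_loss_weight * z_loss
```

Gating on the weight keeps the default path free of the extra logsumexp, which is what makes the 0.0 default a strict no-op.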
Scope
config/model.py (field + validation); model/router.py (both routers); model/moe.py (z_loss property); model/transformer.py (get_moe_router_z_loss); scripts/train.py (add to loss + metric). Standard routers only — MoMa unaffected.
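The config change could look roughly like the following. This is a sketch only: it assumes ModelConfig is a dataclass, and the non-negativity check is an assumed validation rule, since the description says "field + validation" without spelling out the rule.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # Only the new field is shown; the real ModelConfig has many more fields.
    moe_router_z_loss_weight: float = 0.0  # 0 = off

    def __post_init__(self) -> None:
        # Assumed rule: a negative weight would reward router-logit growth.
        if self.moe_router_z_loss_weight < 0:
            raise ValueError(
                "moe_router_z_loss_weight must be >= 0, "
                f"got {self.moe_router_z_loss_weight}"
            )
```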
Backward compatibility
Default 0.0 ⇒ z-loss never computed or added ⇒ training/outputs/grads identical. z_loss is a plain attribute (like aux_loss), not a register_buffer/parameter ⇒ not in state_dict ⇒ checkpoints unaffected.
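The checkpoint claim can be checked with a small stand-in module. TinyMoE below is hypothetical and only mimics how a plain attribute behaves on an nn.Module; it is not the actual MoEMLP.

```python
import torch
from torch import nn


class TinyMoE(nn.Module):
    """Hypothetical stand-in for MoEMLP, showing only the attribute handling."""

    def __init__(self) -> None:
        super().__init__()
        self.gate = nn.Linear(4, 2, bias=False)
        # Plain Python attribute: not registered as a buffer or parameter.
        self.z_loss = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.gate(x)
        # Assigning a non-Parameter tensor keeps z_loss a regular attribute.
        self.z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()
        return logits


moe = TinyMoE()
moe(torch.randn(3, 4))
keys = list(moe.state_dict().keys())  # only the gate weight is serialized
```

Because z_loss never enters state_dict, checkpoints saved before and after this change have the same keys.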