
Add MoE router z-loss for training stability (ST-MoE) #117

@amazloumi


What

KF has no router z-loss. The existing z_loss (training/loss.py) is the PaLM LM-head logit regularizer and is unrelated to routing. MoE training is prone to router-logit growth, which leads to instability and expert collapse: top-1 routing at 410M collapses to a single expert, and raising moe_aux_loss_weight doesn't fix it and degrades LM loss. The standard fix is ST-MoE's router z-loss (Zoph et al. 2022), L_z = mean over tokens of (logsumexp(router_logits))², with the logsumexp taken over the expert dimension. It is used in Switch / Mixtral / Megatron-LM / DeepSpeed-MoE.
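To make the formula concrete, here is a small worked sketch in plain Python (no framework, so it is not KF code). The helper names `logsumexp` and `router_z_loss` and the sample logits are made up for illustration. It evaluates L_z on two tokens with four experts, once at a small logit scale and once with the same logits multiplied by 10.

```python
import math

def logsumexp(logits):
    """Numerically stable log(sum(exp(x))) over one token's expert logits."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))

def router_z_loss(router_logits):
    """Mean over tokens of logsumexp(logits)^2; input is [tokens][experts]."""
    per_token = [logsumexp(row) ** 2 for row in router_logits]
    return sum(per_token) / len(per_token)

# Illustrative data: two tokens, four experts.
small = [[0.1, -0.2, 0.0, 0.3], [0.0, 0.2, -0.1, 0.1]]
# Same expert ranking per token, but logits scaled up by 10.
large = [[10.0 * x for x in row] for row in small]

z_small = router_z_loss(small)
z_large = router_z_loss(large)
print(f"z-loss at small scale: {z_small:.3f}")
print(f"z-loss at large scale: {z_large:.3f}")
```

In this example the scaled-up logits give a larger z-loss, which is the growth the penalty is meant to discourage.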

Add ModelConfig.moe_router_z_loss_weight: float = 0.0 (0 = off). Compute the z-loss in the routers (only when the weight > 0), expose it via MoEMLP.z_loss and Transformer.get_moe_router_z_loss() (mirroring aux_loss / get_moe_aux_loss), add weight × z_loss to the loss in scripts/train.py, and log moe/router_z_loss.
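A minimal sketch of the config field and the loss combination, assuming a dataclass-style config. Only the field names moe_router_z_loss_weight and moe_aux_loss_weight come from this issue. The non-negativity check, the 0.01 aux default, and the `total_loss` helper are assumptions for illustration, not KF's actual config/model.py or scripts/train.py.

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Illustrative default only; not taken from KF.
    moe_aux_loss_weight: float = 0.01
    # Proposed in this issue: 0.0 means the router z-loss is off.
    moe_router_z_loss_weight: float = 0.0

    def __post_init__(self):
        # Assumed validation rule: the weight must be non-negative.
        if self.moe_router_z_loss_weight < 0.0:
            raise ValueError("moe_router_z_loss_weight must be >= 0")

def total_loss(cfg, lm_loss, aux_loss, router_z_loss=None):
    """Hypothetical helper: combine the LM loss with the MoE auxiliary terms."""
    loss = lm_loss + cfg.moe_aux_loss_weight * aux_loss
    # Add the z-loss term only when enabled, so weight 0.0 leaves the loss unchanged.
    if cfg.moe_router_z_loss_weight > 0.0 and router_z_loss is not None:
        loss = loss + cfg.moe_router_z_loss_weight * router_z_loss
    return loss

off = ModelConfig()
on = ModelConfig(moe_router_z_loss_weight=1e-3)
print(total_loss(off, lm_loss=2.0, aux_loss=0.5, router_z_loss=9.0))
print(total_loss(on, lm_loss=2.0, aux_loss=0.5, router_z_loss=9.0))
```

With the weight at its default, the z-loss value is ignored entirely, which matches the "0 = off" behaviour described above.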

Scope

config/model.py (field + validation); model/router.py (both routers); model/moe.py (z_loss property); model/transformer.py (get_moe_router_z_loss); scripts/train.py (add to loss + metric). Standard routers only — MoMa unaffected.

Backward compatibility

Default 0.0 ⇒ z-loss never computed or added ⇒ training/outputs/grads identical. z_loss is a plain attribute (like aux_loss), not a register_buffer/parameter ⇒ not in state_dict ⇒ checkpoints unaffected.
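The state_dict claim can be checked with a small PyTorch sketch. `TinyRouter` is a hypothetical stand-in, not the real model/router.py. It stores z_loss as a plain attribute and computes it only when the weight is positive, which is the behaviour this issue proposes.

```python
import torch
import torch.nn as nn

class TinyRouter(nn.Module):
    """Hypothetical stand-in for a KF router; not the real implementation."""

    def __init__(self, d_model, n_experts, z_loss_weight=0.0):
        super().__init__()
        self.gate = nn.Linear(d_model, n_experts, bias=False)
        self.z_loss_weight = z_loss_weight
        self.z_loss = None  # plain attribute: not a buffer, not a parameter

    def forward(self, x):
        logits = self.gate(x)  # shape [tokens, experts]
        if self.z_loss_weight > 0.0:
            # Computed only when enabled: mean over tokens of logsumexp^2.
            self.z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()
        return logits

torch.manual_seed(0)
x = torch.randn(16, 8)

router_on = TinyRouter(d_model=8, n_experts=4, z_loss_weight=1e-3)
router_on(x)
router_off = TinyRouter(d_model=8, n_experts=4)
router_off(x)

keys_on = list(router_on.state_dict().keys())
keys_off = list(router_off.state_dict().keys())
print(keys_on, keys_off)
```

Both routers expose the same state_dict keys (only the gate weight), so in this sketch a checkpoint saved with the z-loss off loads into a model with it on, and vice versa.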

Labels: enhancement (New feature or request)
