Skip to content

Commit 86e9d21

Browse files
committed
Add nadamuon trained dwee / dpwee vit weights. Add comment to muon impl
1 parent c263de1 commit 86e9d21

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

timm/models/vision_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2703,12 +2703,18 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
27032703
'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
27042704
hf_hub_id='timm/',
27052705
input_size=(3, 256, 256), crop_pct=0.95),
2706+
'vit_dwee_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
2707+
hf_hub_id='timm/',
2708+
input_size=(3, 256, 256), crop_pct=0.95),
27062709
'vit_dwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
27072710
hf_hub_id='timm/',
27082711
input_size=(3, 256, 256), crop_pct=0.95),
27092712
'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
27102713
hf_hub_id='timm/',
27112714
input_size=(3, 256, 256), crop_pct=0.95),
2715+
'vit_dpwee_patch16_reg1_gap_256.sbb_nadamuon_in1k': _cfg(
2716+
hf_hub_id='timm/',
2717+
input_size=(3, 256, 256), crop_pct=0.95),
27122718
'vit_dpwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
27132719
hf_hub_id='timm/',
27142720
input_size=(3, 256, 256), crop_pct=0.95),

timm/optim/muon.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,7 @@ def _single_tensor_adamuon(
602602
# RMS-aligned rescaling: normalize by update norm, then scale by shape factor
603603
# Used by AdaMuon paper approach (match_rms_adamw), not by μP approach (rms_to_rms)
604604
if use_rms_norm:
605+
# eq(8) in AdaMuon paper: 0.2 / RMS(update) = 0.2 * sqrt(numel) / frob(update), since RMS(update) = frob(update) / sqrt(numel)
605606
update_norm = update_adaptive.norm().add_(eps)
606607
update_adaptive = update_adaptive / update_norm
607608

0 commit comments

Comments
 (0)