@@ -19,10 +19,15 @@ def log(t, eps = 1e-20):
1919def entropy (prob ):
2020 return (- prob * log (prob )).sum (dim = - 1 )
2121
22- def spectral_entropy_reg_loss_hook (optimizer , weight , * args , ** kwargs ):
22+ def spectral_entropy_reg_loss_hook (optimizer , * args , ** kwargs ):
2323 loss = torch .tensor (0. ).requires_grad_ ()
2424
2525 for param_group in optimizer .param_groups :
26+ if not param ['add_spectral_entropy_reg' ]:
27+ continue
28+
29+ weight = param ['spectral_entropy_reg_weight' ]
30+
2631 for param in param_group ['params' ]:
2732 if param .ndim < 2 :
2833 continue
@@ -80,7 +85,9 @@ def __init__(
8085 grokfast = grokfast ,
8186 grokfast_alpha = grokfast_alpha ,
8287 grokfast_lamb = grokfast_lamb ,
83- grokfast_after_step = grokfast_after_step
88+ grokfast_after_step = grokfast_after_step ,
89+ add_spectral_entropy_reg = add_spectral_entropy_reg ,
90+ spectral_entropy_reg_weight = spectral_entropy_reg_weight
8491 )
8592
8693 super ().__init__ (params , defaults )
@@ -91,7 +98,7 @@ def __init__(
9198 if not add_spectral_entropy_reg :
9299 return
93100
94- self .register_step_pre_hook (partial (spectral_entropy_reg_loss_hook , self , spectral_entropy_reg_weight ))
101+ self .register_step_pre_hook (partial (spectral_entropy_reg_loss_hook , self ))
95102
96103 def turn_on_grokfast (self ):
97104 for group in self .param_groups :
0 commit comments