@@ -23,10 +23,12 @@ def spectral_entropy_reg_loss_hook(optimizer, *args, **kwargs):
2323 loss = torch .tensor (0. ).requires_grad_ ()
2424
2525 for param_group in optimizer .param_groups :
26- if not param ['add_spectral_entropy_reg' ]:
26+ if not param_group ['add_spectral_entropy_reg' ]:
2727 continue
2828
29- weight = param ['spectral_entropy_reg_weight' ]
29+ weight = param_group ['spectral_entropy_reg_weight' ]
30+ use_svd_lowrank = param_group ['use_svd_lowrank' ]
31+ svd_lowrank_kwargs = param_group ['svd_lowrank_kwargs' ]
3032
3133 for param in param_group ['params' ]:
3234 if param .ndim < 2 :
@@ -35,7 +37,11 @@ def spectral_entropy_reg_loss_hook(optimizer, *args, **kwargs):
3537 * _ , row , col = param .shape
3638 reshaped_param = param .reshape (- 1 , row , col )
3739
38- singular_values = torch .linalg .svdvals (reshaped_param )
40+ if not use_svd_lowrank :
41+ singular_values = torch .linalg .svdvals (reshaped_param )
42+ else :
43+ _ , singular_values , _ = torch .svd_lowrank (reshaped_param , ** svd_lowrank_kwargs )
44+
3945 spectral_prob = singular_values .softmax (dim = - 1 )
4046 spectral_entropy = entropy (spectral_prob ).sum ()
4147 loss = loss + spectral_entropy
@@ -59,7 +65,9 @@ def __init__(
5965 grokfast_after_step = 0 ,
6066 normalize_lr = True ,
6167 add_spectral_entropy_reg = False ,
62- spectral_entropy_reg_weight = 0.1
68+ spectral_entropy_reg_weight = 0.1 ,
69+ use_svd_lowrank = False ,
70+ svd_lowrank_kwargs : dict = dict ()
6371 ):
6472 assert lr > 0.
6573 assert all ([0. <= beta <= 1. for beta in betas ])
@@ -87,7 +95,9 @@ def __init__(
8795 grokfast_lamb = grokfast_lamb ,
8896 grokfast_after_step = grokfast_after_step ,
8997 add_spectral_entropy_reg = add_spectral_entropy_reg ,
90- spectral_entropy_reg_weight = spectral_entropy_reg_weight
98+ spectral_entropy_reg_weight = spectral_entropy_reg_weight ,
99+ use_svd_lowrank = use_svd_lowrank ,
100+ svd_lowrank_kwargs = svd_lowrank_kwargs
91101 )
92102
93103 super ().__init__ (params , defaults )
0 commit comments