@@ -19,6 +19,7 @@ def __init__(
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
         eps = 1e-8,
+        regen_reg_rate = 0.,
         grokfast = True,
         grokfast_alpha = 0.98,
         grokfast_lamb = 2.,
@@ -28,7 +29,9 @@ def __init__(
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert weight_decay >= 0.
+        assert regen_reg_rate >= 0.
         assert eps > 0.
+        assert not (weight_decay > 0. and regen_reg_rate > 0.), 'weight decay and regenerative regularization cannot be used together'

         # in order for fair comparison
         # reduce the learning rate by a factor of (1 + grokfast_lamb)
@@ -43,6 +46,7 @@ def __init__(
             betas = betas,
             eps = eps,
             weight_decay = weight_decay,
+            regen_reg_rate = regen_reg_rate,
             grokfast = grokfast,
             grokfast_alpha = grokfast_alpha,
             grokfast_lamb = grokfast_lamb,
@@ -79,20 +83,31 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):

-                grad, lr, wd, beta1, beta2, eps, grokfast, grokfast_after_step, alpha, lamb, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['eps'], group['grokfast'], group['grokfast_after_step'], group['grokfast_alpha'], group['grokfast_lamb'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, beta1, beta2, eps, grokfast, grokfast_after_step, alpha, lamb, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['eps'], group['grokfast'], group['grokfast_after_step'], group['grokfast_alpha'], group['grokfast_lamb'], self.state[p], self._init_lr

                 # decoupled weight decay

                 if wd > 0.:
                     p.mul_(1. - lr / init_lr * wd)

+                # regenerative regularization - ICLR 2024
+                # https://openreview.net/forum?id=lyoOWX0e0O
+
+                if regen_rate > 0. and 'param_init' in state:
+                    param_init = state['param_init']
+
+                    p.lerp_(param_init, lr / init_lr * regen_rate)
+
                 # init state if needed

                 if len(state) == 0:
                     state['steps'] = 0
                     state['exp_avg'] = torch.zeros_like(grad)
                     state['exp_avg_sq'] = torch.zeros_like(grad)

+                    if regen_rate > 0.:
+                        state['param_init'] = p.data.clone()
+
                 # get some of the states

                 exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
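For context on the change above: regenerative regularization (Kumar et al., ICLR 2024, linked in the diff) is a decoupled pull of the weights toward their initial values rather than toward zero. Below is a minimal sketch of that single update in isolation; the tensor and hyperparameter names mirror the diff, but the concrete values are illustrative assumptions, not part of the commit.

import torch

# standalone sketch of the regenerative regularization step (values are made up)
p = torch.randn(10)              # current parameter
param_init = p.clone()           # snapshot stored as state['param_init'] on the first step

lr = 1e-4                        # current learning rate of the param group
init_lr = 1e-4                   # learning rate at construction time (self._init_lr)
regen_rate = 0.1                 # the new regen_reg_rate hyperparameter

t = lr / init_lr * regen_rate    # interpolation weight, scaled like the decoupled weight decay

# in-place linear interpolation: p <- (1 - t) * p + t * param_init,
# i.e. the weights are nudged back toward their initialization instead of toward zero
p.lerp_(param_init, t)

Seen this way, the mutual-exclusion assertion in __init__ reads naturally: decoupled weight decay (p.mul_(1. - lr / init_lr * wd)) is the same kind of pull with the target fixed at zero, so the two are treated as alternative regularizers rather than being combined.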