Commit f757a67

add regenerative regularization from https://arxiv.org/abs/2308.11958
1 parent 5d7cd48 commit f757a67

3 files changed: +26, -2 lines changed

README.md

Lines changed: 9 additions & 0 deletions

@@ -62,3 +62,12 @@ opt.zero_grad()
     url = {https://api.semanticscholar.org/CorpusID:270123846}
 }
 ```
+
+```bibtex
+@misc{kumar2024maintaining,
+    title={Maintaining Plasticity in Continual Learning via Regenerative Regularization},
+    author={Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
+    year={2024},
+    url={https://openreview.net/forum?id=lyoOWX0e0O}
+}
+```

grokfast_pytorch/grokfast.py

Lines changed: 16 additions & 1 deletion

@@ -19,6 +19,7 @@ def __init__(
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
         eps = 1e-8,
+        regen_reg_rate = 0.,
         grokfast = True,
         grokfast_alpha = 0.98,
         grokfast_lamb = 2.,
@@ -28,7 +29,9 @@ def __init__(
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert weight_decay >= 0.
+        assert regen_reg_rate >= 0.
         assert eps > 0.
+        assert not (weight_decay > 0. and regen_reg_rate > 0.), 'weight decay and regenerative regularization cannot be used together'

         # in order for fair comparison
         # reduce the learning rate by a factor of (1 + grokfast_lamb)
@@ -43,6 +46,7 @@ def __init__(
             betas = betas,
             eps = eps,
             weight_decay = weight_decay,
+            regen_reg_rate = regen_reg_rate,
             grokfast = grokfast,
             grokfast_alpha = grokfast_alpha,
             grokfast_lamb = grokfast_lamb,
@@ -79,20 +83,31 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):

-                grad, lr, wd, beta1, beta2, eps, grokfast, grokfast_after_step, alpha, lamb, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['eps'], group['grokfast'], group['grokfast_after_step'], group['grokfast_alpha'], group['grokfast_lamb'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, beta1, beta2, eps, grokfast, grokfast_after_step, alpha, lamb, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['eps'], group['grokfast'], group['grokfast_after_step'], group['grokfast_alpha'], group['grokfast_lamb'], self.state[p], self._init_lr

                 # decoupled weight decay

                 if wd > 0.:
                     p.mul_(1. - lr / init_lr * wd)

+                # regenerative regularization - ICLR 2024
+                # https://openreview.net/forum?id=lyoOWX0e0O
+
+                if regen_rate > 0. and 'param_init' in state:
+                    param_init = state['param_init']
+
+                    p.lerp_(param_init, lr / init_lr * regen_rate)
+
                 # init state if needed

                 if len(state) == 0:
                     state['steps'] = 0
                     state['exp_avg'] = torch.zeros_like(grad)
                     state['exp_avg_sq'] = torch.zeros_like(grad)

+                    if regen_rate > 0.:
+                        state['param_init'] = p.data.clone()
+
                 # get some of the states

                 exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
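In short, the change above takes a snapshot of each parameter the first time its optimizer state is created (`state['param_init']`) and, before the Adam update, nudges the parameter back toward that snapshot via `lerp_`, scaled in the same decoupled fashion as the weight decay branch above it. The following is a minimal standalone sketch of just that regularization step; the helper name and toy tensors are illustrative and not part of the library's API.

```python
import torch

@torch.no_grad()
def regen_reg_update(p, param_init, lr, init_lr, regen_reg_rate):
    # p.lerp_(target, w) computes p <- p + w * (target - p),
    # i.e. an L2-style pull toward the initial parameters rather than
    # toward zero (which is what decoupled weight decay would do)
    if regen_reg_rate > 0.:
        p.lerp_(param_init, lr / init_lr * regen_reg_rate)

# toy usage
p = torch.randn(4)
param_init = p.clone()           # snapshot taken once, when optimizer state is initialized
p.add_(torch.randn(4) * 0.1)     # stand-in for a few gradient updates
regen_reg_update(p, param_init, lr = 1e-4, init_lr = 1e-4, regen_reg_rate = 1e-2)
```

With `lr` equal to `init_lr` this reduces to `p <- p + regen_reg_rate * (param_init - p)`, mirroring the scaling used for decoupled weight decay elsewhere in the optimizer.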

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "grokfast-pytorch"
-version = "0.0.5"
+version = "0.0.7"
 description = "Grokfast"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
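As a usage sketch of the new argument (hedged: it assumes the package exports the `GrokFastAdamW` optimizer and follows the usual construct / step / zero_grad pattern shown in the README), `regen_reg_rate` is passed at construction time in place of `weight_decay`, since the added assert forbids enabling both.

```python
import torch
from grokfast_pytorch import GrokFastAdamW  # assumed export; adjust to the actual class name

model = torch.nn.Linear(512, 256)

# regenerative regularization in place of weight decay
# (the diff asserts weight_decay and regen_reg_rate cannot both be > 0)
opt = GrokFastAdamW(
    model.parameters(),
    lr = 1e-4,
    weight_decay = 0.,
    regen_reg_rate = 1e-2   # hypothetical value, tune per task
)

loss = model(torch.randn(2, 512)).sum()
loss.backward()

opt.step()
opt.zero_grad()
```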
