Skip to content

Commit e963ef7

Browse files
committed
fix spectral entropy pathway and add option to use svd_lowrank
1 parent eadb5d7 commit e963ef7

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ opt.zero_grad()
6565

6666
```bibtex
6767
@misc{kumar2024maintaining,
68-
title={Maintaining Plasticity in Continual Learning via Regenerative Regularization},
69-
author={Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
70-
year={2024},
71-
url={https://openreview.net/forum?id=lyoOWX0e0O}
68+
title = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
69+
author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
70+
year = {2024},
71+
url = {https://openreview.net/forum?id=lyoOWX0e0O}
7272
}
7373
```
7474

grokfast_pytorch/grokfast.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def spectral_entropy_reg_loss_hook(optimizer, *args, **kwargs):
2323
loss = torch.tensor(0.).requires_grad_()
2424

2525
for param_group in optimizer.param_groups:
26-
if not param['add_spectral_entropy_reg']:
26+
if not param_group['add_spectral_entropy_reg']:
2727
continue
2828

29-
weight = param['spectral_entropy_reg_weight']
29+
weight = param_group['spectral_entropy_reg_weight']
30+
use_svd_lowrank = param_group['use_svd_lowrank']
31+
svd_lowrank_kwargs = param_group['svd_lowrank_kwargs']
3032

3133
for param in param_group['params']:
3234
if param.ndim < 2:
@@ -35,7 +37,11 @@ def spectral_entropy_reg_loss_hook(optimizer, *args, **kwargs):
3537
*_, row, col = param.shape
3638
reshaped_param = param.reshape(-1, row, col)
3739

38-
singular_values = torch.linalg.svdvals(reshaped_param)
40+
if not use_svd_lowrank:
41+
singular_values = torch.linalg.svdvals(reshaped_param)
42+
else:
43+
_, singular_values, _ = torch.svd_lowrank(reshaped_param, **svd_lowrank_kwargs)
44+
3945
spectral_prob = singular_values.softmax(dim = -1)
4046
spectral_entropy = entropy(spectral_prob).sum()
4147
loss = loss + spectral_entropy
@@ -59,7 +65,9 @@ def __init__(
5965
grokfast_after_step = 0,
6066
normalize_lr = True,
6167
add_spectral_entropy_reg = False,
62-
spectral_entropy_reg_weight = 0.1
68+
spectral_entropy_reg_weight = 0.1,
69+
use_svd_lowrank = False,
70+
svd_lowrank_kwargs: dict = dict()
6371
):
6472
assert lr > 0.
6573
assert all([0. <= beta <= 1. for beta in betas])
@@ -87,7 +95,9 @@ def __init__(
8795
grokfast_lamb = grokfast_lamb,
8896
grokfast_after_step = grokfast_after_step,
8997
add_spectral_entropy_reg = add_spectral_entropy_reg,
90-
spectral_entropy_reg_weight = spectral_entropy_reg_weight
98+
spectral_entropy_reg_weight = spectral_entropy_reg_weight,
99+
use_svd_lowrank = use_svd_lowrank,
100+
svd_lowrank_kwargs = svd_lowrank_kwargs
91101
)
92102

93103
super().__init__(params, defaults)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "grokfast-pytorch"
3-
version = "0.0.10"
3+
version = "0.0.11"
44
description = "Grokfast"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)