Skip to content

Commit 1d9e344

Browse files
committed
add ability for the grokfast optimizer to automatically do spectral entropy reg pre step
1 parent f757a67 commit 1d9e344

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,14 @@ opt.zero_grad()
7171
url={https://openreview.net/forum?id=lyoOWX0e0O}
7272
}
7373
```
74+
75+
```bibtex
76+
@inproceedings{anonymous2024the,
77+
title = {The Complexity Dynamics of Grokking},
78+
author = {Anonymous},
79+
booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
80+
year = {2024},
81+
url = {https://openreview.net/forum?id=07N9jCfIE4},
82+
note = {under review}
83+
}
84+
```

grokfast_pytorch/grokfast.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
2-
from typing import Tuple, Callable
2+
from typing import Callable
3+
4+
from functools import partial
35

46
import torch
57
from torch.optim.optimizer import Optimizer
@@ -9,22 +11,50 @@
911
def exists(val):
1012
return val is not None
1113

14+
# tensor helpers
15+
16+
def log(t, eps = 1e-20):
17+
return t.clamp(min = eps).log()
18+
19+
def entropy(prob):
20+
return (-prob * log(prob)).sum(dim = -1)
21+
22+
def spectral_entropy_reg_loss_hook(optimizer, weight, *args, **kwargs):
23+
loss = torch.tensor(0.).requires_grad_()
24+
25+
for param_group in optimizer.param_groups:
26+
for param in param_group['params']:
27+
if param.ndim < 2:
28+
continue
29+
30+
*_, row, col = param.shape
31+
reshaped_param = param.reshape(-1, row, col)
32+
33+
singular_values = torch.linalg.svdvals(reshaped_param)
34+
spectral_prob = singular_values.softmax(dim = -1)
35+
spectral_entropy = entropy(spectral_prob).sum()
36+
loss = loss + spectral_entropy
37+
38+
(loss * weight).backward()
39+
1240
# class
1341

1442
class GrokFastAdamW(Optimizer):
1543
def __init__(
1644
self,
1745
params,
1846
lr = 1e-4,
19-
betas: Tuple[float, float] = (0.9, 0.99),
47+
betas: tuple[float, float] = (0.9, 0.99),
2048
weight_decay = 0.,
2149
eps = 1e-8,
2250
regen_reg_rate = 0.,
2351
grokfast = True,
2452
grokfast_alpha = 0.98,
2553
grokfast_lamb = 2.,
2654
grokfast_after_step = 0,
27-
normalize_lr = True
55+
normalize_lr = True,
56+
add_spectral_entropy_reg = False,
57+
spectral_entropy_reg_weight = 0.1
2858
):
2959
assert lr > 0.
3060
assert all([0. <= beta <= 1. for beta in betas])
@@ -55,6 +85,14 @@ def __init__(
5585

5686
super().__init__(params, defaults)
5787

88+
# maybe spectral entropy reg
89+
# https://openreview.net/forum?id=07N9jCfIE4
90+
91+
if not add_spectral_entropy_reg:
92+
return
93+
94+
self.register_step_pre_hook(partial(spectral_entropy_reg_loss_hook, self, spectral_entropy_reg_weight))
95+
5896
def turn_on_grokfast(self):
5997
for group in self.param_groups:
6098
group['grokfast'] = True

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "grokfast-pytorch"
3-
version = "0.0.7"
3+
version = "0.0.9"
44
description = "Grokfast"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)