Skip to content

Commit eadb5d7

Browse files
committed
allow for customizing which parameters get the spectral entropy reg
1 parent 1d9e344 commit eadb5d7

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

grokfast_pytorch/grokfast.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@ def log(t, eps = 1e-20):
1919
def entropy(prob):
2020
return (-prob * log(prob)).sum(dim = -1)
2121

22-
def spectral_entropy_reg_loss_hook(optimizer, weight, *args, **kwargs):
22+
def spectral_entropy_reg_loss_hook(optimizer, *args, **kwargs):
2323
loss = torch.tensor(0.).requires_grad_()
2424

2525
for param_group in optimizer.param_groups:
26+
if not param['add_spectral_entropy_reg']:
27+
continue
28+
29+
weight = param['spectral_entropy_reg_weight']
30+
2631
for param in param_group['params']:
2732
if param.ndim < 2:
2833
continue
@@ -80,7 +85,9 @@ def __init__(
8085
grokfast = grokfast,
8186
grokfast_alpha = grokfast_alpha,
8287
grokfast_lamb = grokfast_lamb,
83-
grokfast_after_step = grokfast_after_step
88+
grokfast_after_step = grokfast_after_step,
89+
add_spectral_entropy_reg = add_spectral_entropy_reg,
90+
spectral_entropy_reg_weight = spectral_entropy_reg_weight
8491
)
8592

8693
super().__init__(params, defaults)
@@ -91,7 +98,7 @@ def __init__(
9198
if not add_spectral_entropy_reg:
9299
return
93100

94-
self.register_step_pre_hook(partial(spectral_entropy_reg_loss_hook, self, spectral_entropy_reg_weight))
101+
self.register_step_pre_hook(partial(spectral_entropy_reg_loss_hook, self))
95102

96103
def turn_on_grokfast(self):
97104
for group in self.param_groups:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "grokfast-pytorch"
3-
version = "0.0.9"
3+
version = "0.0.10"
44
description = "Grokfast"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)