Skip to content

Commit 094c23a

Browse files
committed
reminder
1 parent a6e24be commit 094c23a

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

grokfast_pytorch/grokfast.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,20 @@ def __init__(
2222
grokfast = True,
2323
grokfast_alpha = 0.98,
2424
grokfast_lamb = 2.,
25-
grokfast_after_step = 0
25+
grokfast_after_step = 0,
26+
normalize_lr = True
2627
):
2728
assert lr > 0.
2829
assert all([0. <= beta <= 1. for beta in betas])
2930
assert weight_decay >= 0.
3031
assert eps > 0.
3132

33+
# in order for fair comparison
34+
# reduce the learning rate by a factor of (1 + grokfast_lamb)
35+
36+
if normalize_lr:
37+
lr /= (1. + grokfast_lamb)
38+
3239
self._init_lr = lr
3340

3441
defaults = dict(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "grokfast-pytorch"
3-
version = "0.0.3"
3+
version = "0.0.4"
44
description = "Grokfast"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)