Commit d80a693

last cleanup
1 parent 5e936a3 commit d80a693

1 file changed: +7 -6 lines changed

grokfast_pytorch/grokfast.py

Lines changed: 7 additions & 6 deletions
@@ -96,15 +96,16 @@ def step(

                 should_grokfast = grokfast and steps > grokfast_after_step

-                if should_grokfast and not 'grok_exp_avg' in state:
-                    # maintain an ema of the grad
-                    # for amplifying slow gradients, as paper claims it accelerates generalization
-
-                    state['grok_exp_avg'] = grad.clone()
-
                 # take care of grok fast if turned on

                 if should_grokfast:
+
+                    if 'grok_exp_avg' not in state:
+                        # maintain an ema of the grad
+                        # for amplifying slow gradients, as paper claims it accelerates generalization
+
+                        state['grok_exp_avg'] = grad.clone()
+
                     grok_exp_avg = state['grok_exp_avg']

                     # update grok exp avg
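
For readers who want to see the control flow after this cleanup in isolation, here is a minimal sketch using plain floats instead of tensors. The function name `grokfast_step`, the parameters `alpha` and `lamb`, their default values, and the EMA update and amplification lines are assumptions for illustration: the diff stops at the `# update grok exp avg` comment and does not show those lines. Only the `should_grokfast` guard and the lazy creation of `state['grok_exp_avg']` come from the commit itself.

```python
def grokfast_step(state, grad, steps, grokfast=True,
                  grokfast_after_step=0, alpha=0.98, lamb=2.0):
    """Return the (possibly amplified) gradient and update `state` in place.

    `state` is a plain dict standing in for the optimizer's per-parameter
    state; `grad` is a float standing in for a gradient tensor.
    """
    should_grokfast = grokfast and steps > grokfast_after_step

    # take care of grok fast if turned on
    if should_grokfast:

        # lazily create the gradient EMA the first time this branch runs
        # (with a float there is nothing to clone; the commit uses grad.clone())
        if 'grok_exp_avg' not in state:
            state['grok_exp_avg'] = grad

        grok_exp_avg = state['grok_exp_avg']

        # ASSUMPTION: standard exponential moving average update
        grok_exp_avg = alpha * grok_exp_avg + (1.0 - alpha) * grad
        state['grok_exp_avg'] = grok_exp_avg

        # ASSUMPTION: amplify the slow (averaged) gradient component
        grad = grad + lamb * grok_exp_avg

    return grad
```

Note that `not 'grok_exp_avg' in state` and `'grok_exp_avg' not in state` evaluate identically in Python, so the commit changes readability and nesting rather than behavior: the state initialization now lives inside the single `if should_grokfast:` block instead of repeating the condition.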
