We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5e936a3 commit d80a693Copy full SHA for d80a693
grokfast_pytorch/grokfast.py
@@ -96,15 +96,16 @@ def step(
96
97
should_grokfast = grokfast and steps > grokfast_after_step
98
99
- if should_grokfast and not 'grok_exp_avg' in state:
100
- # maintain an ema of the grad
101
- # for amplifying slow gradients, as paper claims it accelerates generalization
102
-
103
- state['grok_exp_avg'] = grad.clone()
104
105
# take care of grok fast if turned on
106
107
if should_grokfast:
+
+ if 'grok_exp_avg' not in state:
+ # maintain an ema of the grad
+ # for amplifying slow gradients, as paper claims it accelerates generalization
+ state['grok_exp_avg'] = grad.clone()
108
109
grok_exp_avg = state['grok_exp_avg']
110
111
# update grok exp avg
0 commit comments