@@ -1,5 +1,7 @@
 from __future__ import annotations
-from typing import Tuple, Callable
+from typing import Callable
+
+from functools import partial
 
 import torch
 from torch.optim.optimizer import Optimizer
@@ -9,22 +11,50 @@
 def exists(val):
     return val is not None
 
+# tensor helpers
+
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
+def entropy(prob):
+    return (-prob * log(prob)).sum(dim = -1)
+
+def spectral_entropy_reg_loss_hook(optimizer, weight, *args, **kwargs):
+    loss = torch.tensor(0.).requires_grad_()
+
+    for param_group in optimizer.param_groups:
+        for param in param_group['params']:
+            if param.ndim < 2:
+                continue
+
+            *_, row, col = param.shape
+            reshaped_param = param.reshape(-1, row, col)
+
+            singular_values = torch.linalg.svdvals(reshaped_param)
+            spectral_prob = singular_values.softmax(dim = -1)
+            spectral_entropy = entropy(spectral_prob).sum()
+            loss = loss + spectral_entropy
+
+    (loss * weight).backward()
+
 # class
 
 class GrokFastAdamW(Optimizer):
     def __init__(
         self,
         params,
         lr = 1e-4,
-        betas: Tuple[float, float] = (0.9, 0.99),
+        betas: tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
         eps = 1e-8,
         regen_reg_rate = 0.,
         grokfast = True,
         grokfast_alpha = 0.98,
         grokfast_lamb = 2.,
         grokfast_after_step = 0,
-        normalize_lr = True
+        normalize_lr = True,
+        add_spectral_entropy_reg = False,
+        spectral_entropy_reg_weight = 0.1
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
@@ -55,6 +85,14 @@ def __init__(
 
         super().__init__(params, defaults)
 
+        # maybe spectral entropy reg
+        # https://openreview.net/forum?id=07N9jCfIE4
+
+        if not add_spectral_entropy_reg:
+            return
+
+        self.register_step_pre_hook(partial(spectral_entropy_reg_loss_hook, self, spectral_entropy_reg_weight))
+
     def turn_on_grokfast(self):
         for group in self.param_groups:
             group['grokfast'] = True
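
For reference, this is what the new step pre-hook computes, written out as a standalone sketch. Every parameter with at least two dimensions has its singular values turned into a probability distribution by a softmax, the Shannon entropy of that distribution is summed across parameters, and the result, scaled by spectral_entropy_reg_weight, is backpropagated just before optimizer.step() so its gradients are added on top of whatever task-loss gradients are already in .grad. The tensor shape, the weight tensor, and the reg_weight name below are illustrative only, not part of the commit.

    import torch

    # illustrative weight; the hook treats any parameter with ndim >= 2 this way
    weight = torch.randn(512, 256, requires_grad = True)

    # singular values of the (batched) matrix, shape (1, 256)
    singular_values = torch.linalg.svdvals(weight.reshape(-1, *weight.shape[-2:]))

    # softmax over singular values -> distribution, then Shannon entropy
    spectral_prob = singular_values.softmax(dim = -1)
    spectral_entropy = (-spectral_prob * spectral_prob.clamp(min = 1e-20).log()).sum(dim = -1).sum()

    reg_weight = 0.1  # default spectral_entropy_reg_weight in this commit
    (reg_weight * spectral_entropy).backward()

    print(weight.grad.shape)  # torch.Size([512, 256]) - gradients flow back through svdvals

Because it is registered with register_step_pre_hook, the regularizer needs no change to the training loop: its gradient contribution is accumulated right before each parameter update.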
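A minimal usage sketch of the interface added by this commit. The model, the dummy loss, and the package-level import of GrokFastAdamW are placeholders assumed for illustration; only the two new constructor arguments come from the diff above.

    import torch
    from torch import nn
    from grokfast_pytorch import GrokFastAdamW  # assumed package-level export

    model = nn.Linear(256, 128)

    opt = GrokFastAdamW(
        model.parameters(),
        lr = 1e-4,
        add_spectral_entropy_reg = True,    # opt into the new regularizer
        spectral_entropy_reg_weight = 0.1   # scale on the backpropagated entropy term
    )

    loss = model(torch.randn(2, 256)).sum()  # placeholder task loss
    loss.backward()

    opt.step()       # pre-hook adds the spectral entropy gradients, then the update runs
    opt.zero_grad()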