
Commit 8c2d84f

github-actions[bot] and gabrieldemarmiesse authored
[Backport r0.8] Fix LAMB optimizer regex parsing (#1555)
* Fix LAMB optimizer regex parsing

* Fix conflict.

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: gabrieldemarmiesse <gabrieldemarmiesse@gmail.com>
1 parent a9e86e8 commit 8c2d84f

2 files changed: +19, -5 lines


tensorflow_addons/optimizers/lamb.py

Lines changed: 5 additions & 5 deletions

@@ -19,7 +19,7 @@
 """
 
 import re
-from typing import Optional, Union, Callable
+from typing import Optional, Union, Callable, List
 from typeguard import typechecked
 
 import tensorflow as tf
@@ -42,8 +42,8 @@ def __init__(
         beta_2: FloatTensorLike = 0.999,
         epsilon: FloatTensorLike = 1e-6,
         weight_decay_rate: FloatTensorLike = 0.0,
-        exclude_from_weight_decay: Optional[str] = None,
-        exclude_from_layer_adaptation: Optional[str] = None,
+        exclude_from_weight_decay: Optional[List[str]] = None,
+        exclude_from_layer_adaptation: Optional[List[str]] = None,
         name: str = "LAMB",
         **kwargs
     ):
@@ -59,10 +59,10 @@ def __init__(
                 The exponential decay rate for the 2nd moment estimates.
             epsilon: A small constant for numerical stability.
             weight_decay_rate: weight decay rate.
-            exclude_from_weight_decay: comma separated name patterns of
+            exclude_from_weight_decay: List of regex patterns of
                 variables excluded from weight decay. Variables whose name
                 contain a substring matching the pattern will be excluded.
-            exclude_from_layer_adaptation: comma separated name patterns of
+            exclude_from_layer_adaptation: List of regex patterns of
                 variables excluded from layer adaptation. Variables whose name
                 contain a substring matching the pattern will be excluded.
             name: Optional name for the operations created when applying

tensorflow_addons/optimizers/lamb_test.py

Lines changed: 14 additions & 0 deletions

@@ -401,6 +401,20 @@ def test_get_config(self):
         config = opt.get_config()
         self.assertEqual(config["learning_rate"], 1e-4)
 
+    def test_exclude_weight_decay(self):
+        opt = lamb.LAMB(
+            0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
+        )
+        assert opt._do_use_weight_decay("var0")
+        assert not opt._do_use_weight_decay("var1")
+        assert not opt._do_use_weight_decay("var1_weight")
+
+    def test_exclude_layer_adaptation(self):
+        opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
+        assert opt._do_layer_adaptation("var0")
+        assert not opt._do_layer_adaptation("var1")
+        assert not opt._do_layer_adaptation("var1_weight")
+
 
 if __name__ == "__main__":
     tf.test.main()
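The new tests encode why "var1" also excludes "var1_weight": each pattern is searched anywhere in the variable name. A standalone sketch of that check, assuming the private helpers follow the usual re.search loop (this function is a stand-in for illustration, not the method body from this diff):

import re
from typing import List, Optional

def do_use_weight_decay(param_name: str, exclude: Optional[List[str]]) -> bool:
    # Skip weight decay if any pattern matches a substring of the name,
    # which is why "var1" also catches "var1_weight".
    if exclude:
        for pattern in exclude:
            if re.search(pattern, param_name) is not None:
                return False
    return True

print(do_use_weight_decay("var0", ["var1"]))         # True
print(do_use_weight_decay("var1_weight", ["var1"]))  # False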
