1919"""
2020
2121import re
22- from typing import Optional , Union , Callable
22+ from typing import Optional , Union , Callable , List
2323from typeguard import typechecked
2424
2525import tensorflow as tf
@@ -42,8 +42,8 @@ def __init__(
4242 beta_2 : FloatTensorLike = 0.999 ,
4343 epsilon : FloatTensorLike = 1e-6 ,
4444 weight_decay_rate : FloatTensorLike = 0.0 ,
45- exclude_from_weight_decay : Optional [str ] = None ,
46- exclude_from_layer_adaptation : Optional [str ] = None ,
45+ exclude_from_weight_decay : Optional [List [ str ] ] = None ,
46+ exclude_from_layer_adaptation : Optional [List [ str ] ] = None ,
4747 name : str = "LAMB" ,
4848 ** kwargs
4949 ):
@@ -59,10 +59,10 @@ def __init__(
5959 The exponential decay rate for the 2nd moment estimates.
6060 epsilon: A small constant for numerical stability.
6161 weight_decay_rate: weight decay rate.
62- exclude_from_weight_decay: comma separated name patterns of
62+ exclude_from_weight_decay: List of regex patterns of
6363 variables excluded from weight decay. Variables whose name
6464 contain a substring matching the pattern will be excluded.
65- exclude_from_layer_adaptation: comma separated name patterns of
65+ exclude_from_layer_adaptation: List of regex patterns of
6666 variables excluded from layer adaptation. Variables whose name
6767 contain a substring matching the pattern will be excluded.
6868 name: Optional name for the operations created when applying
0 commit comments