 
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
+from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes
 
 from typeguard import typechecked
-from typing import Union, Callable, Type
+from typing import Union, Callable, Type, Optional, List
 
 
 class DecoupledWeightDecayExtension:
@@ -71,24 +72,40 @@ def __init__(self, weight_decay, *args, **kwargs):
     """
 
     @typechecked
-    def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs):
+    def __init__(
+        self,
+        weight_decay: Union[FloatTensorLike, Callable],
+        exclude_from_weight_decay: Optional[List[str]] = None,
+        **kwargs,
+    ):
         """Extension class that adds weight decay to an optimizer.
 
         Args:
             weight_decay: A `Tensor`, a floating point value, or a schedule
                 that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
                 to decay the variable by, in the update step.
+            exclude_from_weight_decay: List of regex patterns of
+                variables excluded from weight decay. Variables whose name
+                contains a substring matching the pattern will be excluded.
+                Note `decay_var_list` in `minimize` or `apply_gradients` takes
+                priority over `exclude_from_weight_decay` if specified.
             **kwargs: Optional list or tuple or set of `Variable` objects to
                 decay.
         """
         wd = kwargs.pop("weight_decay", weight_decay)
         super().__init__(**kwargs)
         self._decay_var_list = None  # is set in minimize or apply_gradients
         self._set_hyper("weight_decay", wd)
+        self.exclude_from_weight_decay = exclude_from_weight_decay
 
     def get_config(self):
         config = super().get_config()
-        config.update({"weight_decay": self._serialize_hyperparameter("weight_decay")})
+        config.update(
+            {
+                "weight_decay": self._serialize_hyperparameter("weight_decay"),
+                "exclude_from_weight_decay": self.exclude_from_weight_decay,
+            }
+        )
         return config
 
     @classmethod
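Since `get_config` now also serializes `exclude_from_weight_decay`, the setting should survive a config round-trip. A minimal sketch (not part of the diff), assuming the usual Keras `get_config`/`from_config` machinery and using the `AdamW` subclass defined later in this file; the patterns are only illustrative:

```python
import tensorflow_addons as tfa

# Illustrative patterns; any regexes matching variable names would do.
opt = tfa.optimizers.AdamW(weight_decay=1e-4, exclude_from_weight_decay=["bias"])

config = opt.get_config()
assert config["exclude_from_weight_decay"] == ["bias"]

# Rebuilding from the config keeps the exclusion patterns.
restored = tfa.optimizers.AdamW.from_config(config)
assert restored.exclude_from_weight_decay == ["bias"]
```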
@@ -130,7 +147,8 @@ def minimize(
             grad_loss: Optional. A `Tensor` holding the gradient computed for
                 `loss`.
             decay_var_list: Optional list of variables to be decayed. Defaults
-                to all variables in var_list.
+                to all variables in var_list. Note `decay_var_list` takes
+                priority over `exclude_from_weight_decay` if specified.
             name: Optional name for the returned operation.
             tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
                 `Tensor`, the tape that computed the `loss` must be provided.
@@ -154,10 +172,11 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
 
         Args:
             grads_and_vars: List of (gradient, variable) pairs.
-            name: Optional name for the returned operation. Default to the
+            name: Optional name for the returned operation. Default to the
                 name passed to the `Optimizer` constructor.
             decay_var_list: Optional list of variables to be decayed. Defaults
-                to all variables in var_list.
+                to all variables in var_list. Note `decay_var_list` takes
+                priority over `exclude_from_weight_decay` if specified.
             **kwargs: Additional arguments to pass to the base optimizer's
                 apply_gradient method, e.g., TF2.2 added an argument
                 `experimental_aggregate_gradients`.
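A short sketch (not part of the diff) of the priority rule documented above: a variable whose name matches an exclusion pattern is still decayed when it is listed explicitly in `decay_var_list`. Variable names, shapes, and the loss are illustrative:

```python
import tensorflow as tf
import tensorflow_addons as tfa

opt = tfa.optimizers.SGDW(
    weight_decay=1e-3, learning_rate=0.1, exclude_from_weight_decay=["bias"]
)

w = tf.Variable(tf.ones([4, 4]), name="kernel")
b = tf.Variable(tf.zeros([4]), name="bias")

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.matmul(tf.ones([1, 4]), w) + b)
grads = tape.gradient(loss, [w, b])

# `b` matches the "bias" exclusion pattern, but because it appears in
# `decay_var_list`, it is decayed anyway.
opt.apply_gradients(zip(grads, [w, b]), decay_var_list=[w, b])
```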
@@ -173,7 +192,7 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
         return super().apply_gradients(grads_and_vars, name=name, **kwargs)
 
     def _decay_weights_op(self, var, apply_state=None):
-        if not self._decay_var_list or var.ref() in self._decay_var_list:
+        if self._do_use_weight_decay(var):
             var_device, var_dtype = var.device, var.dtype.base_dtype
             coefficients = (apply_state or {}).get(
                 (var_device, var_dtype)
@@ -183,7 +202,7 @@ def _decay_weights_op(self, var, apply_state=None):
         return tf.no_op()
 
     def _decay_weights_sparse_op(self, var, indices, apply_state=None):
-        if not self._decay_var_list or var.ref() in self._decay_var_list:
+        if self._do_use_weight_decay(var):
             var_device, var_dtype = var.device, var.dtype.base_dtype
             coefficients = (apply_state or {}).get(
                 (var_device, var_dtype)
@@ -226,6 +245,12 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
                 grad, var, indices, apply_state=apply_state
             )
 
+    def _do_use_weight_decay(self, var):
+        """Whether to use L2 weight decay for `var`."""
+        if self._decay_var_list and var.ref() in self._decay_var_list:
+            return True
+        return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)
+
 
 @typechecked
 def extend_with_decoupled_weight_decay(
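`_do_use_weight_decay` delegates the regex check to `is_variable_matched_by_regexes`, imported at the top of the file. Its actual implementation lives in `tensorflow_addons/optimizers/utils.py`; the following is only a rough sketch of the behavior implied by the docstrings above (substring match of each pattern against the variable name, with `None` or an empty list excluding nothing):

```python
import re


def is_variable_matched_by_regexes(variable, regexes):
    """Return True if any pattern in `regexes` matches `variable.name`."""
    if regexes:
        for pattern in regexes:
            # `re.search` implements the "contains a substring matching
            # the pattern" semantics described in the docstrings.
            if re.search(pattern, variable.name):
                return True
    return False
```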
@@ -243,9 +268,13 @@ def extend_with_decoupled_weight_decay(
     The API of the new optimizer class slightly differs from the API of the
     base optimizer:
     - The first argument to the constructor is the weight decay rate.
+    - Optional keyword argument `exclude_from_weight_decay` accepts a list of
+      regex patterns of variables excluded from weight decay. Variables whose
+      name contains a substring matching the pattern will be excluded.
     - `minimize` and `apply_gradients` accept the optional keyword argument
       `decay_var_list`, which specifies the variables that should be decayed.
-      If `None`, all variables that are optimized are decayed.
+      Note this takes priority over `exclude_from_weight_decay` if specified.
+      If both are `None`, all variables that are optimized are decayed.
 
     Usage example:
     ```python
@@ -376,12 +405,14 @@ def __init__(
             nesterov: boolean. Whether to apply Nesterov momentum.
             name: Optional name prefix for the operations created when applying
                 gradients. Defaults to 'SGD'.
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
-                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
-                norm; `clipvalue` is clip gradients by value, `decay` is
-                included for backward compatibility to allow time inverse decay
-                of learning rate. `lr` is included for backward compatibility,
-                recommended to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
+                gradients by norm; `clipvalue` is clip gradients by value.
+                `decay` is included for backward compatibility to allow time
+                inverse decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+                `exclude_from_weight_decay` accepts a list of regex patterns of
+                variables excluded from weight decay.
         """
         super().__init__(
             weight_decay,
@@ -466,12 +497,14 @@ def __init__(
                 beyond".
             name: Optional name for the operations created when applying
                 gradients. Defaults to "AdamW".
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
-                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
-                norm; `clipvalue` is clip gradients by value, `decay` is
-                included for backward compatibility to allow time inverse decay
-                of learning rate. `lr` is included for backward compatibility,
-                recommended to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
+                gradients by norm; `clipvalue` is clip gradients by value.
+                `decay` is included for backward compatibility to allow time
+                inverse decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+                `exclude_from_weight_decay` accepts a list of regex patterns of
+                variables excluded from weight decay.
         """
         super().__init__(
             weight_decay,
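Finally, a usage sketch (not part of the diff) of the new keyword with `AdamW`; the model and layer names are only illustrative, chosen so that the normalization parameters and biases match the exclusion patterns:

```python
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(64, activation="relu", input_shape=(32,)),
        tf.keras.layers.LayerNormalization(name="layer_norm"),
        tf.keras.layers.Dense(10),
    ]
)

# Decay all trainable weights except LayerNormalization parameters and biases.
optimizer = tfa.optimizers.AdamW(
    weight_decay=1e-4,
    learning_rate=1e-3,
    exclude_from_weight_decay=["layer_norm", "bias"],
)
model.compile(optimizer=optimizer, loss="mse")
```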