@@ -157,9 +157,7 @@ def minimize(
157157 Raises:
158158 ValueError: If some of the variables are not `Variable` objects.
159159 """
160- self ._decay_var_list = (
161- set ([v .ref () for v in decay_var_list ]) if decay_var_list else False
162- )
160+ self ._set_decay_var_list (var_list , decay_var_list )
163161 return super ().minimize (
164162 loss , var_list = var_list , grad_loss = grad_loss , name = name , tape = tape
165163 )
@@ -186,9 +184,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
186184 TypeError: If `grads_and_vars` is malformed.
187185 ValueError: If none of the variables have gradients.
188186 """
189- self ._decay_var_list = (
190- set ([v .ref () for v in decay_var_list ]) if decay_var_list else False
191- )
187+ grads_and_vars = list (grads_and_vars )
188+ self ._set_decay_var_list ((v for _ , v in grads_and_vars ), decay_var_list )
192189 return super ().apply_gradients (grads_and_vars , name = name , ** kwargs )
193190
194191 def _decay_weights_op (self , var , apply_state = None ):
@@ -245,11 +242,23 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
245242 grad , var , indices , apply_state = apply_state
246243 )
247244
245+ def _set_decay_var_list (self , var_list , decay_var_list = None ):
246+ if decay_var_list :
247+ self ._decay_var_list = set (v .ref () for v in decay_var_list )
248+ elif self .exclude_from_weight_decay :
249+ self ._decay_var_list = set (
250+ v .ref ()
251+ for v in var_list
252+ if not is_variable_matched_by_regexes (v , self .exclude_from_weight_decay )
253+ )
254+ else :
255+ self ._decay_var_list = None
256+
248257 def _do_use_weight_decay (self , var ):
249258 """Whether to use L2 weight decay for `var`."""
250- if self ._decay_var_list and var . ref () in self . _decay_var_list :
259+ if self ._decay_var_list is None :
251260 return True
252- return not is_variable_matched_by_regexes ( var , self .exclude_from_weight_decay )
261+ return var . ref () in self ._decay_var_list
253262
254263
255264@typechecked
0 commit comments