
Commit 339159f

Point optimizer to tf.keras.optimizers.legacy.Optimizer to be compatible with Keras optimizer migration (#2706)
* Point optimizer to tf.keras.optimizers.legacy.Optimizer to be compatible with Keras optimizer migration
* small fix
* add version control
* small fix
* Update discriminative_layer_training.py
* fix version control
* small fix
* move optimizer class to __init__.py
* small fix
* fix problems
* small fix
* Rename BaseOptimizer to KerasLegacyOptimizer
* exclude keras optimizer from type check
* fix import
1 parent 3e264f9 commit 339159f

21 files changed: 105 additions, 35 deletions
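The user-facing API of the affected optimizers should be unchanged by this migration; only the base class they derive from is now selected per TensorFlow version. A minimal usage sketch (model, data shape, and hyperparameters are illustrative, assuming tensorflow and tensorflow-addons are installed):

import tensorflow as tf
import tensorflow_addons as tfa

# AdaBelief (changed below) still plugs into compile() as before; on
# TF >= 2.9 it now derives from tf.keras.optimizers.legacy.Optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=tfa.optimizers.AdaBelief(learning_rate=1e-3), loss="mse")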

tensorflow_addons/optimizers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
 # ==============================================================================
 """Additional optimizers that conform to Keras API."""

+from tensorflow_addons.optimizers.constants import KerasLegacyOptimizer
 from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper
 from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient
 from tensorflow_addons.optimizers.cyclical_learning_rate import CyclicalLearningRate

tensorflow_addons/optimizers/adabelief.py

Lines changed: 2 additions & 1 deletion

@@ -17,11 +17,12 @@
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike

+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typing import Union, Callable, Dict


 @tf.keras.utils.register_keras_serializable(package="Addons")
-class AdaBelief(tf.keras.optimizers.Optimizer):
+class AdaBelief(KerasLegacyOptimizer):
     """Variant of the Adam optimizer.

     It achieves fast convergence as Adam and generalization comparable to SGD.

tensorflow_addons/optimizers/average_wrapper.py

Lines changed: 7 additions & 4 deletions

@@ -17,12 +17,12 @@
 import warnings

 import tensorflow as tf
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from tensorflow_addons.utils import types
-
 from typeguard import typechecked


-class AveragedOptimizerWrapper(tf.keras.optimizers.Optimizer, metaclass=abc.ABCMeta):
+class AveragedOptimizerWrapper(KerasLegacyOptimizer, metaclass=abc.ABCMeta):
     @typechecked
     def __init__(
         self, optimizer: types.Optimizer, name: str = "AverageOptimizer", **kwargs
@@ -32,9 +32,12 @@ def __init__(
         if isinstance(optimizer, str):
             optimizer = tf.keras.optimizers.get(optimizer)

-        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
+        if not isinstance(
+            optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
+        ):
             raise TypeError(
-                "optimizer is not an object of tf.keras.optimizers.Optimizer"
+                "optimizer is not an object of tf.keras.optimizers.Optimizer "
+                "or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
             )

         self._optimizer = optimizer
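For the widened type check above, a minimal sketch of what the wrapper now accepts (MovingAverage is one subclass of AveragedOptimizerWrapper; the example assumes TF >= 2.9 so that the legacy namespace exists):

import tensorflow as tf
import tensorflow_addons as tfa

# A legacy optimizer instance passes the isinstance check; on TF < 2.9 the
# KerasLegacyOptimizer alias simply points at tf.keras.optimizers.Optimizer,
# so the previous behavior is preserved.
avg = tfa.optimizers.MovingAverage(tf.keras.optimizers.legacy.SGD(learning_rate=0.01))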

tensorflow_addons/optimizers/cocob.py

Lines changed: 3 additions & 1 deletion

@@ -17,9 +17,11 @@
 from typeguard import typechecked
 import tensorflow as tf

+from tensorflow_addons.optimizers import KerasLegacyOptimizer
+

 @tf.keras.utils.register_keras_serializable(package="Addons")
-class COCOB(tf.keras.optimizers.Optimizer):
+class COCOB(KerasLegacyOptimizer):
     """Optimizer that implements COCOB Backprop Algorithm

     Reference:

tensorflow_addons/optimizers/conditional_gradient.py

Lines changed: 2 additions & 1 deletion

@@ -15,14 +15,15 @@
 """Conditional Gradient optimizer."""

 import tensorflow as tf
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from tensorflow_addons.utils.types import FloatTensorLike

 from typeguard import typechecked
 from typing import Union, Callable


 @tf.keras.utils.register_keras_serializable(package="Addons")
-class ConditionalGradient(tf.keras.optimizers.Optimizer):
+class ConditionalGradient(KerasLegacyOptimizer):
     """Optimizer that implements the Conditional Gradient optimization.

     This optimizer helps handle constraints well.
tensorflow_addons/optimizers/constants.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import importlib
+import tensorflow as tf
+
+if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
+    KerasLegacyOptimizer = tf.keras.optimizers.legacy.Optimizer
+else:
+    KerasLegacyOptimizer = tf.keras.optimizers.Optimizer
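Since the alias is re-exported from the package __init__ (see the first diff), downstream code can inspect which base class was selected. A small check, assuming tensorflow-addons is installed:

import tensorflow_addons as tfa

# Prints tf.keras.optimizers.legacy.Optimizer on TF >= 2.9,
# tf.keras.optimizers.Optimizer on older releases.
print(tfa.optimizers.KerasLegacyOptimizer)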

tensorflow_addons/optimizers/cyclical_learning_rate.py

Lines changed: 4 additions & 4 deletions

@@ -58,7 +58,7 @@ def __init__(
         ```

         You can pass this schedule directly into a
-        `tf.keras.optimizers.Optimizer` as the learning rate.
+        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.

         Args:
             initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -146,7 +146,7 @@ def __init__(
         ```

         You can pass this schedule directly into a
-        `tf.keras.optimizers.Optimizer` as the learning rate.
+        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.

         Args:
             initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -215,7 +215,7 @@ def __init__(
         ```

         You can pass this schedule directly into a
-        `tf.keras.optimizers.Optimizer` as the learning rate.
+        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.

         Args:
             initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -286,7 +286,7 @@ def __init__(
         ```

         You can pass this schedule directly into a
-        `tf.keras.optimizers.Optimizer` as the learning rate.
+        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.

         Args:
             initial_learning_rate: A scalar `float32` or `float64` `Tensor` or

tensorflow_addons/optimizers/discriminative_layer_training.py

Lines changed: 5 additions & 3 deletions

@@ -17,14 +17,16 @@
 from typing import List, Union

 import tensorflow as tf
+
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typeguard import typechecked

 from keras import backend
 from keras.utils import tf_utils


 @tf.keras.utils.register_keras_serializable(package="Addons")
-class MultiOptimizer(tf.keras.optimizers.Optimizer):
+class MultiOptimizer(KerasLegacyOptimizer):
     """Multi Optimizer Wrapper for Discriminative Layer Training.

     Creates a wrapper around a set of instantiated optimizer layer pairs.
@@ -33,7 +35,7 @@ class MultiOptimizer(tf.keras.optimizers.Optimizer):
     Each optimizer will optimize only the weights associated with its paired layer.
     This can be used to implement discriminative layer training by assigning
     different learning rates to each optimizer layer pair.
-    `(tf.keras.optimizers.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
+    `(tf.keras.optimizers.legacy.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
     Please note that the layers must be instantiated before instantiating the optimizer.

     Args:
@@ -149,7 +151,7 @@ def get_config(self):
     @classmethod
     def create_optimizer_spec(
         cls,
-        optimizer: tf.keras.optimizers.Optimizer,
+        optimizer: KerasLegacyOptimizer,
         layers_or_model: Union[
             tf.keras.Model,
             tf.keras.Sequential,
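A minimal sketch of how MultiOptimizer is typically constructed after this change (layer sizes and learning rates are illustrative; the legacy Adam is used here assuming TF >= 2.9):

import tensorflow as tf
import tensorflow_addons as tfa

# Each (optimizer, layer) pair trains only the weights of its paired layer.
dense1 = tf.keras.layers.Dense(8)
dense2 = tf.keras.layers.Dense(1)
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), dense1, dense2])

optimizers_and_layers = [
    (tf.keras.optimizers.legacy.Adam(learning_rate=1e-3), dense1),
    (tf.keras.optimizers.legacy.Adam(learning_rate=1e-4), dense2),
]
model.compile(optimizer=tfa.optimizers.MultiOptimizer(optimizers_and_layers), loss="mse")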

tensorflow_addons/optimizers/lamb.py

Lines changed: 2 additions & 1 deletion

@@ -24,12 +24,13 @@
 from typeguard import typechecked

 import tensorflow as tf
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from tensorflow_addons.utils.types import FloatTensorLike
 from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


 @tf.keras.utils.register_keras_serializable(package="Addons")
-class LAMB(tf.keras.optimizers.Optimizer):
+class LAMB(KerasLegacyOptimizer):
     """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

     See paper [Large Batch Optimization for Deep Learning: Training BERT

tensorflow_addons/optimizers/lazy_adam.py

Lines changed: 8 additions & 1 deletion

@@ -20,15 +20,22 @@
 original Adam algorithm, and may lead to different empirical results.
 """

+import importlib
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike

 from typeguard import typechecked
 from typing import Union, Callable


+if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
+    adam_optimizer_class = tf.keras.optimizers.legacy.Adam
+else:
+    adam_optimizer_class = tf.keras.optimizers.Adam
+
+
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class LazyAdam(tf.keras.optimizers.Adam):
+class LazyAdam(adam_optimizer_class):
     """Variant of the Adam optimizer that handles sparse updates more
     efficiently.

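LazyAdam keeps its public constructor; only the Adam base class selected above changes with the TF version. A short usage sketch, assuming tensorflow-addons is installed:

import tensorflow_addons as tfa

# Same call regardless of whether the legacy or the classic Adam was picked.
opt = tfa.optimizers.LazyAdam(learning_rate=1e-3)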
