Skip to content

Commit 25e96d3

Browse files
authored
Fix Keras imports. (#2829)
* Fix Keras imports. * Check TF version instead of importing in try-catch.
1 parent 9c1642a commit 25e96d3

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

tensorflow_addons/optimizers/discriminative_layer_training.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,18 @@
1818

1919
import tensorflow as tf
2020

21+
from packaging.version import Version
2122
from tensorflow_addons.optimizers import KerasLegacyOptimizer
2223
from typeguard import typechecked
2324

24-
from keras import backend
25-
from keras.utils import tf_utils
25+
if Version(tf.__version__).release >= Version("2.13").release:
26+
# New versions of Keras require importing from `keras.src` when
27+
# importing internal symbols.
28+
from keras.src import backend
29+
from keras.src.utils import tf_utils
30+
else:
31+
from keras import backend
32+
from keras.utils import tf_utils
2633

2734

2835
@tf.keras.utils.register_keras_serializable(package="Addons")

tensorflow_addons/utils/test_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
from tensorflow_addons import options
2727
from tensorflow_addons.utils import resource_loader
2828

29-
if Version(tf.__version__) >= Version("2.9"):
29+
if Version(tf.__version__).release >= Version("2.13").release:
30+
# New versions of Keras require importing from `keras.src` when
31+
# importing internal symbols.
32+
from keras.src.testing_infra.test_utils import layer_test # noqa: F401
33+
elif Version(tf.__version__) >= Version("2.9"):
3034
from keras.testing_infra.test_utils import layer_test # noqa: F401
3135
else:
3236
from keras.testing_utils import layer_test # noqa: F401

tensorflow_addons/utils/types.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,14 @@
2020
import numpy as np
2121
import tensorflow as tf
2222

23+
from packaging.version import Version
24+
2325
# TODO: Remove once https://github.com/tensorflow/tensorflow/issues/44613 is resolved
24-
if tf.__version__[:3] > "2.5":
26+
if Version(tf.__version__).release >= Version("2.13").release:
27+
# New versions of Keras require importing from `keras.src` when
28+
# importing internal symbols.
29+
from keras.src.engine import keras_tensor
30+
elif Version(tf.__version__).release >= Version("2.5").release:
2531
from keras.engine import keras_tensor
2632
else:
2733
from tensorflow.python.keras.engine import keras_tensor

0 commit comments

Comments
 (0)