
Commit 4436057

Remove one of the private imports of tf.keras (#2655)

* Remove one of the private imports of tf.keras
* Fix format issue
* More format fixes
* Even more format fixes
* Fix for flake8
* Remove the unnecessary white space
1 parent b8cab7f commit 4436057

File tree

1 file changed (+5, -4 lines)

tensorflow_addons/seq2seq/attention_wrapper.py

Lines changed: 5 additions & 4 deletions
@@ -34,9 +34,6 @@
 from typeguard import typechecked
 from typing import Optional, Callable, Union, List
 
-# TODO: Find public API alternatives to these
-from tensorflow.python.keras.engine import base_layer_utils
-
 
 class AttentionMechanism(tf.keras.layers.Layer):
     """Base class for attention mechanisms.
@@ -275,7 +272,11 @@ def setup_memory(self, memory, memory_sequence_length=None, memory_mask=None):
         # passed from __call__(), which does not have proper keras metadata.
         # TODO(omalleyt12): Remove this hack once the mask the has proper
         # keras history.
-        base_layer_utils.mark_checked(self.values)
+
+        def _mark_checked(tensor):
+            tensor._keras_history_checked = True  # pylint: disable=protected-access
+
+        tf.nest.map_structure(_mark_checked, self.values)
         if self.memory_layer is not None:
             self.keys = self.memory_layer(self.values)
         else:
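The replacement reimplements what base_layer_utils.mark_checked did, but through the public tf.nest.map_structure traversal plus a direct write of the private _keras_history_checked attribute. A minimal sketch of the same pattern outside the layer, with an illustrative nested structure (the dict keys and tensor shapes are assumptions for the example, not from the commit):

import tensorflow as tf


def _mark_checked(tensor):
    # Mirror of the committed helper: flag the tensor so Keras skips
    # its keras-history check. _keras_history_checked is a private
    # attribute, so this is a workaround rather than a supported API.
    tensor._keras_history_checked = True


# tf.nest.map_structure visits every leaf tensor in an arbitrarily
# nested structure (dicts, lists, tuples), which is how one call can
# cover whatever structure self.values holds inside the layer.
values = {"memory": tf.zeros([2, 5, 8]), "mask": tf.ones([2, 5])}
tf.nest.map_structure(_mark_checked, values)

Dropping the tensorflow.python.keras.engine import matters because that module path is private and can move or disappear between TensorFlow releases; the nest-based version depends only on the public tf.nest API, at the cost of still touching one private tensor attribute.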
