Commit df256b9

kazemnejad authored and guillaumekln committed
Allow manual memory reset in AttentionMechanism (#547)
* Add support for manual memory reset
* Better code style in MyModel class; remove solved #TODO
1 parent ea26f14 commit df256b9
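For readers skimming the diff below, a minimal sketch of the usage this change enables: an attention mechanism whose memory is already initialized can be pointed at a new memory tensor by calling it again with setup_memory=True. The module path, shapes, and hyperparameters here are illustrative assumptions, not taken from the commit.

import tensorflow as tf
import tensorflow_addons as tfa

attention = tfa.seq2seq.LuongAttention(units=8)

memory_a = tf.random.normal([4, 10, 8])   # [batch, max_time, depth]
attention(memory_a, setup_memory=True)    # first memory setup

memory_b = tf.random.normal([4, 12, 8])   # a new source batch
attention(memory_b, setup_memory=True)    # manual reset enabled by this commit

query = tf.random.normal([4, 8])          # decoder query, depth == units
state = tf.zeros([4, 12])                 # previous alignments over memory_b
alignments, next_state = attention([query, state])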

File tree: 2 files changed, +54 -1 lines changed

tensorflow_addons/seq2seq/attention_wrapper.py

Lines changed: 5 additions & 0 deletions

@@ -179,6 +179,10 @@ def __call__(self, inputs, **kwargs):
           inputs: the inputs tensors.
           **kwargs: dict, other keyeword arguments for the `__call__()`
         """
+        # Allow manual memory reset
+        if kwargs.get('setup_memory', False):
+            self._memory_initialized = False
+
         if self._memory_initialized:
             if len(inputs) not in (2, 3):
                 raise ValueError(
@@ -188,6 +192,7 @@ def __call__(self, inputs, **kwargs):
             # We append the calculated memory here so that the graph will be
             # connected.
             inputs.append(self.values)
+
         return super(_BaseAttentionMechanism, self).__call__(inputs, **kwargs)
 
     def call(self, inputs, mask=None, setup_memory=False, **kwargs):
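The reset is handled in __call__ rather than in call, so it takes effect before the query/state validation above and before Keras dispatches to the layer. A hedged sketch of the reuse pattern this buys when the mechanism is shared with an AttentionWrapper cell; the wrapper setup, names, and shapes below are assumptions for illustration, not part of this commit:

import tensorflow as tf
import tensorflow_addons as tfa

units = 8
mechanism = tfa.seq2seq.BahdanauAttention(units)
cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(units), mechanism, attention_layer_size=units)

for encoder_output in (tf.random.normal([4, 10, units]),
                       tf.random.normal([4, 15, units])):
    # Re-point the shared mechanism at this batch's encoder output.
    mechanism(encoder_output, setup_memory=True)
    # A fresh wrapper state is needed because the alignment size follows
    # the memory's time dimension.
    state = cell.get_initial_state(batch_size=4, dtype=tf.float32)
    step_input = tf.random.normal([4, units])
    output, state = cell(step_input, state)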

tensorflow_addons/seq2seq/attention_wrapper_test.py

Lines changed: 49 additions & 1 deletion

@@ -209,7 +209,55 @@ def test_masking(self):
         alignment = self.evaluate(alignment)
         self.assertEqual(np.sum(np.triu(alignment, k=1)), 0)
 
-    # TODO(scottzhu): Add tests for model.compile(run_eagerly=True)
+    @parameterized.named_parameters(
+        ("luong", wrapper.LuongAttention),
+        ("luong_monotonic", wrapper.LuongMonotonicAttention),
+        ("bahdanau", wrapper.BahdanauAttention),
+        ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttention),
+    )
+    def test_memory_re_setup(self, attention_cls):
+        class MyModel(tf.keras.models.Model):
+            def __init__(self, vocab, embedding_dim, memory_size, units):
+                super(MyModel, self).__init__()
+                self.emb = tf.keras.layers.Embedding(
+                    vocab, embedding_dim, mask_zero=True)
+                self.encoder = tf.keras.layers.LSTM(
+                    memory_size, return_sequences=True)
+                self.attn_mch = attention_cls(units)
+
+            def call(self, inputs):
+                enc_input, query, state = inputs
+                mask = self.emb.compute_mask(enc_input)
+                enc_input = self.emb(enc_input)
+                enc_output = self.encoder(enc_input, mask=mask)
+                # To ensure manual resetting also works in the graph mode,
+                # we call the attention mechanism twice.
+                self.attn_mch(enc_output, mask=mask, setup_memory=True)
+                self.attn_mch(enc_output, mask=mask, setup_memory=True)
+                score = self.attn_mch([query, state])
+                return score
+
+        vocab = 20
+        embedding_dim = 6
+        num_batches = 5
+
+        model = MyModel(vocab, embedding_dim, self.memory_size, self.units)
+        if tf.executing_eagerly():
+            model.compile("rmsprop", "mse", run_eagerly=True)
+        else:
+            model.compile("rmsprop", "mse")
+
+        x = np.random.randint(
+            vocab, size=(num_batches * self.batch, self.timestep))
+        x_test = np.random.randint(
+            vocab, size=(num_batches * self.batch, self.timestep))
+        y = np.random.randn(num_batches * self.batch, self.timestep)
+
+        query = np.tile(self.query, [num_batches, 1])
+        state = np.tile(self.state, [num_batches, 1])
+
+        model.fit([x, query, state], (y, y), batch_size=self.batch)
+        model.predict_on_batch([x_test, query, state])
 
 
 class ResultSummary(
