@@ -209,7 +209,55 @@ def test_masking(self):
209209 alignment = self .evaluate (alignment )
210210 self .assertEqual (np .sum (np .triu (alignment , k = 1 )), 0 )
211211
212- # TODO(scottzhu): Add tests for model.compile(run_eagerly=True)
212+ @parameterized .named_parameters (
213+ ("luong" , wrapper .LuongAttention ),
214+ ("luong_monotonic" , wrapper .LuongMonotonicAttention ),
215+ ("bahdanau" , wrapper .BahdanauAttention ),
216+ ("bahdanau_monotonic" , wrapper .BahdanauMonotonicAttention ),
217+ )
218+ def test_memory_re_setup (self , attention_cls ):
219+ class MyModel (tf .keras .models .Model ):
220+ def __init__ (self , vocab , embedding_dim , memory_size , units ):
221+ super (MyModel , self ).__init__ ()
222+ self .emb = tf .keras .layers .Embedding (
223+ vocab , embedding_dim , mask_zero = True )
224+ self .encoder = tf .keras .layers .LSTM (
225+ memory_size , return_sequences = True )
226+ self .attn_mch = attention_cls (units )
227+
228+ def call (self , inputs ):
229+ enc_input , query , state = inputs
230+ mask = self .emb .compute_mask (enc_input )
231+ enc_input = self .emb (enc_input )
232+ enc_output = self .encoder (enc_input , mask = mask )
233+ # To ensure manual resetting also works in the graph mode,
234+ # we call the attention mechanism twice.
235+ self .attn_mch (enc_output , mask = mask , setup_memory = True )
236+ self .attn_mch (enc_output , mask = mask , setup_memory = True )
237+ score = self .attn_mch ([query , state ])
238+ return score
239+
240+ vocab = 20
241+ embedding_dim = 6
242+ num_batches = 5
243+
244+ model = MyModel (vocab , embedding_dim , self .memory_size , self .units )
245+ if tf .executing_eagerly ():
246+ model .compile ("rmsprop" , "mse" , run_eagerly = True )
247+ else :
248+ model .compile ("rmsprop" , "mse" )
249+
250+ x = np .random .randint (
251+ vocab , size = (num_batches * self .batch , self .timestep ))
252+ x_test = np .random .randint (
253+ vocab , size = (num_batches * self .batch , self .timestep ))
254+ y = np .random .randn (num_batches * self .batch , self .timestep )
255+
256+ query = np .tile (self .query , [num_batches , 1 ])
257+ state = np .tile (self .state , [num_batches , 1 ])
258+
259+ model .fit ([x , query , state ], (y , y ), batch_size = self .batch )
260+ model .predict_on_batch ([x_test , query , state ])
213261
214262
215263class ResultSummary (
0 commit comments