@@ -194,6 +194,7 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
         self.eos = label_vec_size - 1 if eos is None else eos
         self.max_seq_lens = max_seq_lens
         self.num_eos = 3
+        self.tfr = 1.
 
         Hs, Hc, Hy = rnn_hidden_size, listen_vec_size, label_vec_size
 
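Note: the teacher forcing rate `tfr` now lives on the Speller itself and defaults to 1.0, i.e. pure teacher forcing. The commit leaves any decay of `tfr` to the caller; the sketch below is a hypothetical linear annealing schedule (the `anneal_tfr` name, `decay`, and `min_tfr` values are illustrative, not from this repo):

    # hypothetical schedule: decay the Speller's teacher forcing rate per epoch
    def anneal_tfr(speller, epoch, decay=0.05, min_tfr=0.5):
        # linear decay from 1.0, clamped so some ground truth is always fed
        speller.tfr = max(min_tfr, 1.0 - decay * epoch)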
@@ -222,6 +223,9 @@ def get_mask(self, h, seq_lens):
             mask[b, seq_lens[b]:] = 0.
         return mask
 
+    def _is_sample_step(self):
+        return np.random.random_sample() < self.tfr
+
     def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
         batch_size = h.size(0)
         sos = int2onehot(h.new_full((batch_size, 1), self.sos), num_classes=self.label_vec_size).float()
@@ -261,9 +265,9 @@ def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
             if y_hats_seq_lens.le(t + 1).all():
                 break
 
-            if y is None:
+            if y is None or not self._is_sample_step():  # non-sampling step
                 x = torch.cat([y_hat, c], dim=-1)
-            elif t < y.size(1):  # teach force
+            elif t < y.size(1):  # scheduled sampling step
                 x = torch.cat([y.narrow(1, t, 1), c], dim=-1)
             else:
                 x = torch.cat([eos, c], dim=-1)
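For context, the rewritten branch implements scheduled sampling: at each decoding step the ground-truth label is fed with probability `self.tfr`, and the model's own previous prediction `y_hat` is fed back otherwise. A minimal standalone sketch of the same per-step decision (the function and argument names here are illustrative, not from the repo):

    import numpy as np

    def next_decoder_input(y_hat, y_true, tfr):
        # feed the ground-truth token with probability tfr (teacher forcing);
        # otherwise feed back the model's own previous prediction
        if y_true is not None and np.random.random_sample() < tfr:
            return y_true
        return y_hat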
@@ -335,7 +339,6 @@ def __init__(self, label_vec_size=p.NUM_CTC_LABELS, listen_vec_size=256,
         self.eos = self.label_vec_size - 1
 
         self.num_heads = num_attend_heads
-        self.tfr = 1.
 
         self.listen = Listener(listen_vec_size=listen_vec_size, input_folding=input_folding, rnn_type=nn.LSTM,
                                rnn_hidden_size=listen_vec_size, rnn_num_layers=4, bidirectional=True,
@@ -357,9 +360,6 @@ def forward(self, x, x_seq_lens, y=None, y_seq_lens=None):
         else:
             return self._eval_forward(x, x_seq_lens)
 
-    def _is_teacher_force(self):
-        return np.random.random_sample() < self.tfr
-
     def _train_forward(self, x, x_seq_lens, y, y_seq_lens):
         # to remove the case of x_seq_lens < y_seq_lens and y_seq_lens > max_seq_lens
         bi = x_seq_lens.gt(y_seq_lens) * y_seq_lens.lt(self.spell.max_seq_lens)
@@ -376,15 +376,11 @@ def _train_forward(self, x, x_seq_lens, y, y_seq_lens):
         ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=self.blk)
         ys, ys_seq_lens = ys[bi], y_seq_lens[bi] + self.spell.num_eos
 
-        if self._is_teacher_force():
-            # speller with teach force rate including noise
-            floor = np.random.random_sample() * 1e-2
-            yss = int2onehot(ys, num_classes=self.label_vec_size, floor=floor).float()
-            noise = torch.rand_like(yss) * 0.1
-            yss = yss * noise
-            y_hats, y_hats_seq_lens, self.attentions = self.spell(h, x_seq_lens, yss, ys_seq_lens)
-        else:
-            y_hats, y_hats_seq_lens, self.attentions = self.spell(h, x_seq_lens)
+        floor = np.random.random_sample() * 0.1
+        yss = int2onehot(ys, num_classes=self.label_vec_size, floor=floor).float()
+        noise = torch.rand_like(yss) * 0.1
+        yss = F.softmax(yss * noise, dim=-1)
+        y_hats, y_hats_seq_lens, self.attentions = self.spell(h, x_seq_lens, yss, ys_seq_lens)
 
         # add regions to attentions
         self.regions = torch.IntTensor([(frames - 1, labels - 1) for frames, labels in zip(x_seq_lens, ys_seq_lens)])
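The teacher inputs are no longer hard one-hot vectors: `int2onehot` receives a random smoothing `floor` (now drawn up to 0.1 instead of 1e-2), the result is scaled by uniform noise, and a softmax renormalizes it into a proper distribution. A rough standalone sketch of the same construction, assuming `int2onehot` places 1 at the target index and `floor` elsewhere (the exact semantics of the repo's helper are an assumption):

    import numpy as np
    import torch
    import torch.nn.functional as F

    def soft_targets(ys, num_classes):
        # ys: LongTensor of label indices
        # one-hot with a random smoothing floor in [0, 0.1)
        floor = np.random.random_sample() * 0.1
        yss = torch.full((*ys.shape, num_classes), floor)
        yss.scatter_(-1, ys.unsqueeze(-1), 1.0)
        # multiplicative uniform noise, then renormalize to a distribution
        noise = torch.rand_like(yss) * 0.1
        return F.softmax(yss * noise, dim=-1)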