@@ -200,7 +200,7 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
         self.rnn_num_layers = rnn_num_layers
         self.rnns = rnn_type(input_size=(Hy + Hc), hidden_size=Hs, num_layers=rnn_num_layers,
                              bias=True, bidirectional=False, batch_first=True)
-        # self.norm = nn.LayerNorm(Hs, elementwise_affine=False)
+        self.norm = nn.LayerNorm(Hs, elementwise_affine=False)
 
         self.attention = Attention(state_vec_size=Hs, listen_vec_size=Hc,
                                    apply_proj=apply_attend_proj, proj_hidden_size=proj_hidden_size,
@@ -242,7 +242,7 @@ def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
 
         for t in range(self.max_seq_lens):
             s, hidden = self.rnns(x, hidden)
-            # s = self.norm(s)
+            s = self.norm(s)
             c, a = self.attention(s, h, in_mask)
             y_hat = self.chardist(torch.cat([s, c], dim=-1))
             y_hat = self.softmax(y_hat)
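The two hunks above enable a previously commented-out LayerNorm on the speller's RNN output, so the decoder state is normalized at every step before it reaches the attention module and the character distribution. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the sizes and the LSTM input are invented, and only `Hs` and the norm-then-attend ordering mirror the diff.

```python
# Minimal sketch (assumed sizes, not the repo's wiring): LayerNorm with
# elementwise_affine=False has no learnable gain/bias; it only rescales each
# per-step decoder state to zero mean / unit variance over the feature dim.
import torch
import torch.nn as nn

Hs = 256                                    # decoder state size (assumed)
rnn = nn.LSTM(input_size=64, hidden_size=Hs, batch_first=True)
norm = nn.LayerNorm(Hs, elementwise_affine=False)

x = torch.randn(8, 1, 64)                   # one decoding step, batch of 8
s, hidden = rnn(x)                          # s: (8, 1, Hs)
s = norm(s)                                 # normalize before attention consumes s
print(s.mean(-1).abs().max().item(), s.std(-1).mean().item())  # ~0 and ~1
```

Because `elementwise_affine=False` adds no parameters, the normalization can be switched on, as this commit does, without changing the checkpoint layout.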
@@ -407,7 +407,7 @@ def _eval_forward(self, x, x_seq_lens):
         h = self.listen(x, x_seq_lens)
         # spell
         y_hats, y_hats_seq_lens, _ = self.spell(h, x_seq_lens)
-        y_hats_seq_lens[y_hats.seq_lens.ne(self.spell.max_seq_lens)].sub_(self.spell.num_eos)
+        y_hats_seq_lens[y_hats_seq_lens.ne(self.spell.max_seq_lens)].sub_(self.spell.num_eos)
 
         # return with seq lens without sos and eos
         y_hats = self.log(y_hats[:, :, :-2])
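This hunk fixes a name typo: `y_hats.seq_lens` is not an attribute of the hypothesis tensor, so the mask of sequences that terminated before `max_seq_lens` has to be built from `y_hats_seq_lens` itself, and the EOS count is subtracted from those lengths. One caveat worth noting, shown in the standalone sketch below (invented lengths, not the repository's data): boolean-mask indexing in PyTorch returns a copy, so an in-place `sub_()` chained onto it does not write back into the original tensor, whereas an indexed augmented assignment does.

```python
# Standalone illustration with invented lengths: masked indexing copies,
# so chaining an in-place op onto it leaves the source tensor untouched,
# while an indexed `-=` assigns the result back through __setitem__.
import torch

max_seq_lens, num_eos = 256, 1
seq_lens = torch.tensor([12, 256, 7, 256])       # 256 == never emitted EOS

a = seq_lens.clone()
a[a.ne(max_seq_lens)].sub_(num_eos)              # modifies a temporary copy only
print(a)                                         # tensor([ 12, 256,   7, 256])

b = seq_lens.clone()
b[b.ne(max_seq_lens)] -= num_eos                 # writes back into b
print(b)                                         # tensor([ 11, 256,   6, 256])
```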