Skip to content

Commit 8db3d30

Browse files
committed
Reinsert LayerNorm into the Spell decoder; change NLLLoss reduction from 'none' to 'mean'
1 parent bda2010 commit 8db3d30

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

asr/models/las/network.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
200200
self.rnn_num_layers = rnn_num_layers
201201
self.rnns = rnn_type(input_size=(Hy + Hc), hidden_size=Hs, num_layers=rnn_num_layers,
202202
bias=True, bidirectional=False, batch_first=True)
203-
#self.norm = nn.LayerNorm(Hs, elementwise_affine=False)
203+
self.norm = nn.LayerNorm(Hs, elementwise_affine=False)
204204

205205
self.attention = Attention(state_vec_size=Hs, listen_vec_size=Hc,
206206
apply_proj=apply_attend_proj, proj_hidden_size=proj_hidden_size,
@@ -242,7 +242,7 @@ def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
242242

243243
for t in range(self.max_seq_lens):
244244
s, hidden = self.rnns(x, hidden)
245-
#s = self.norm(s)
245+
s = self.norm(s)
246246
c, a = self.attention(s, h, in_mask)
247247
y_hat = self.chardist(torch.cat([s, c], dim=-1))
248248
y_hat = self.softmax(y_hat)
@@ -407,7 +407,7 @@ def _eval_forward(self, x, x_seq_lens):
407407
h = self.listen(x, x_seq_lens)
408408
# spell
409409
y_hats, y_hats_seq_lens, _ = self.spell(h, x_seq_lens)
410-
y_hats_seq_lens[y_hats.seq_lens.ne(self.spell.max_seq_lens)].sub_(self.spell.num_eos)
410+
y_hats_seq_lens[y_hats_seq_lens.ne(self.spell.max_seq_lens)].sub_(self.spell.num_eos)
411411

412412
# return with seq lens without sos and eos
413413
y_hats = self.log(y_hats[:, :, :-2])

asr/models/las/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class LASTrainer(NonSplitTrainer):
2727
def __init__(self, *args, **kwargs):
2828
super().__init__(*args, **kwargs)
2929

30-
self.loss = nn.NLLLoss(reduction='none', ignore_index=self.model.blk)
30+
self.loss = nn.NLLLoss(reduction='mean', ignore_index=self.model.blk)
3131

3232
self.tfr_scheduler = TFRScheduler(self.model, ranges=(0.9, 0.0), warm_up=0, epochs=9, restart=True)
3333
#self.tfr_scheduler.step(9)
@@ -68,7 +68,8 @@ def unit_train(self, data):
6868
ys_hat = ys_hat.float()
6969
if self.use_cuda:
7070
ys_lens = ys_lens.cuda()
71-
loss = self.loss(ys_hat.transpose(1, 2), ys.long()).sum(dim=-1).div(ys_lens.float()).mean()
71+
loss = self.loss(ys_hat.transpose(1, 2), ys.long())
72+
#loss = self.loss(ys_hat.transpose(1, 2), ys.long()).sum(dim=-1).div(ys_lens.float()).mean()
7273
#if ys_hat_lens is None:
7374
# logger.debug("the batch includes a data with label_lens > max_seq_lens: ignore the entire batch")
7475
# loss.mul_(0)

0 commit comments

Comments (0)