Commit 0522b58

phone based scheduled sampling
1 parent cf03bda commit 0522b58

3 files changed: +86 −23 lines

asr/models/las/loss.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import torch
+from torch.nn.modules.loss import _Loss
+
+from asr.utils.misc import onehot2int
+
+class EditDistanceLoss(_Loss):
+    __constants__ = ['reduction']
+
+    def __init__(self, size_average=None, reduce=None, reduction='mean'):
+        super().__init__(size_average, reduce, reduction)
+
+    def forward(self, input, target, input_seq_lens, target_seq_lens):
+        """
+        input: BxTxH, target: BxN, input_seq_lens: B, target_seq_lens: B
+        """
+        batch_size = input.size(0)
+        eds = list()
+        for b in range(batch_size):
+            # argmax-decode each utterance up to its valid length
+            x = torch.argmax(input[b, :input_seq_lens[b]], dim=-1)
+            y = target[b, :target_seq_lens[b]]
+            d = self.calculate_levenshtein(x, y)
+            eds.append(d)
+        loss = torch.FloatTensor(eds)
+
+        if self.reduction == 'none':
+            return loss
+        elif self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+
+    def calculate_levenshtein(self, seq1, seq2):
+        """
+        the extension of the Wagner-Fischer dynamic programming algorithm
+        with adjacent transpositions (optimal string alignment distance)
+        """
+        # the DP table needs one extra row/column for the empty prefixes
+        size_x, size_y = len(seq1) + 1, len(seq2) + 1
+        matrix = torch.zeros((size_x, size_y))
+        for x in range(size_x):
+            matrix[x, 0] = x
+        for y in range(size_y):
+            matrix[0, y] = y
+
+        for x in range(1, size_x):
+            for y in range(1, size_y):
+                # matrix[x, y] covers seq1[:x] vs seq2[:y], so compare at x-1 / y-1
+                cost = 0 if seq1[x - 1] == seq2[y - 1] else 1
+                comps = torch.LongTensor([
+                    matrix[x - 1, y] + 1,          # deletion
+                    matrix[x, y - 1] + 1,          # insertion
+                    matrix[x - 1, y - 1] + cost,   # substitution
+                ])
+                matrix[x, y] = torch.min(comps)
+                if x > 1 and y > 1 and seq1[x - 1] == seq2[y - 2] and seq1[x - 2] == seq2[y - 1]:
+                    comps = torch.LongTensor([
+                        matrix[x, y],
+                        matrix[x - 2, y - 2] + cost,   # transposition
+                    ])
+                    matrix[x, y] = torch.min(comps)
+
+        return matrix[-1, -1]
+
+if __name__ == "__main__":
+    x = torch.LongTensor([0, 1, 2])
+    y = torch.LongTensor([0, 2, 1, 3])
+    l = EditDistanceLoss()
+    print(l.calculate_levenshtein(x, y))   # tensor(2.): one transposition + one insertion
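
Editor's note on this new loss (not part of the commit): forward() argmax-decodes each utterance up to its valid input length and scores it against the trimmed target, but the per-utterance distances are collected into a fresh FloatTensor, so no gradient can flow back through the network; that is consistent with EditDistanceLoss remaining commented out as a training loss in train.py below. A minimal usage sketch, with every shape and value invented for illustration:

import torch

from asr.models.las.loss import EditDistanceLoss

# hypothetical toy batch: B=2 utterances, T=4 frames, H=5 labels
scores = torch.randn(2, 4, 5)                              # BxTxH label scores
targets = torch.LongTensor([[0, 2, 1, 3], [1, 4, 0, 0]])   # BxN, padded
input_lens = torch.LongTensor([4, 3])
target_lens = torch.LongTensor([4, 2])

criterion = EditDistanceLoss(reduction='mean')
print(criterion(scores, targets, input_lens, target_lens)) # mean edit distance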

asr/models/las/network.py

Lines changed: 11 additions & 15 deletions
@@ -194,6 +194,7 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
         self.eos = label_vec_size - 1 if eos is None else eos
         self.max_seq_lens = max_seq_lens
         self.num_eos = 3
+        self.tfr = 1.

         Hs, Hc, Hy = rnn_hidden_size, listen_vec_size, label_vec_size

@@ -222,6 +223,9 @@ def get_mask(self, h, seq_lens):
             mask[b, seq_lens[b]:] = 0.
         return mask

+    def _is_sample_step(self):
+        return np.random.random_sample() < self.tfr
+
     def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
         batch_size = h.size(0)
         sos = int2onehot(h.new_full((batch_size, 1), self.sos), num_classes=self.label_vec_size).float()

@@ -261,9 +265,9 @@ def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
             if y_hats_seq_lens.le(t + 1).all():
                 break

-            if y is None:
+            if y is None or not self._is_sample_step():  # feed back the model's own prediction
                 x = torch.cat([y_hat, c], dim=-1)
-            elif t < y.size(1):  # teach force
+            elif t < y.size(1):  # scheduled sampling: feed the ground-truth label
                 x = torch.cat([y.narrow(1, t, 1), c], dim=-1)
             else:
                 x = torch.cat([eos, c], dim=-1)

@@ -335,7 +339,6 @@ def __init__(self, label_vec_size=p.NUM_CTC_LABELS, listen_vec_size=256,
         self.eos = self.label_vec_size - 1

         self.num_heads = num_attend_heads
-        self.tfr = 1.

         self.listen = Listener(listen_vec_size=listen_vec_size, input_folding=input_folding, rnn_type=nn.LSTM,
                                rnn_hidden_size=listen_vec_size, rnn_num_layers=4, bidirectional=True,

@@ -357,9 +360,6 @@ def forward(self, x, x_seq_lens, y=None, y_seq_lens=None):
         else:
             return self._eval_forward(x, x_seq_lens)

-    def _is_teacher_force(self):
-        return np.random.random_sample() < self.tfr
-
     def _train_forward(self, x, x_seq_lens, y, y_seq_lens):
         # to remove the case of x_seq_lens < y_seq_lens and y_seq_lens > max_seq_lens
         bi = x_seq_lens.gt(y_seq_lens) * y_seq_lens.lt(self.spell.max_seq_lens)

@@ -376,15 +376,11 @@ def _train_forward(self, x, x_seq_lens, y, y_seq_lens):
         ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=self.blk)
         ys, ys_seq_lens = ys[bi], y_seq_lens[bi] + self.spell.num_eos

-        if self._is_teacher_force():
-            # speller with teach force rate including noise
-            floor = np.random.random_sample() * 1e-2
-            yss = int2onehot(ys, num_classes=self.label_vec_size, floor=floor).float()
-            noise = torch.rand_like(yss) * 0.1
-            yss = yss * noise
-            y_hats, y_hats_seq_lens, self.attentions = self.spell(h, x_seq_lens, yss, ys_seq_lens)
-        else:
-            y_hats, y_hats_seq_lens, self.attentions = self.spell(h, x_seq_lens)
+        floor = np.random.random_sample() * 0.1
+        yss = int2onehot(ys, num_classes=self.label_vec_size, floor=floor).float()
+        noise = torch.rand_like(yss) * 0.1
+        yss = F.softmax(yss * noise, dim=-1)
+        y_hats, y_hats_seq_lens, self.attentions = self.spell(h, x_seq_lens, yss, ys_seq_lens)

         # add regions to attentions
         self.regions = torch.IntTensor([(frames - 1, labels - 1) for frames, labels in zip(x_seq_lens, ys_seq_lens)])
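
Two mechanics in this diff are worth spelling out. The teacher-forcing rate tfr now lives on the speller, and _is_sample_step() draws a fresh uniform sample at every decoding step, so each step independently feeds the ground-truth label with probability tfr and the model's own previous prediction otherwise; this is per-step scheduled sampling rather than the old per-batch coin flip in _is_teacher_force(). Relatedly, _train_forward() now always builds noised soft targets (a random one-hot floor, multiplicative uniform noise, then a softmax) and hands them to the speller, which makes the sample-or-force decision itself.

TFRScheduler, which anneals tfr during training, is not shown in this commit. A hypothetical stand-in consistent only with the call sites visible in train.py (ranges=(0.9, 0.0), warm_up, epochs, restart, step(), state_dict()/load_state_dict()) might look like this; the real class may differ:

class LinearTFRScheduler:
    # hypothetical sketch, not the project's TFRScheduler: linearly anneals
    # model.tfr from ranges[0] to ranges[1] over `epochs` epochs after a warm-up
    def __init__(self, model, ranges=(0.9, 0.0), warm_up=0, epochs=9, restart=False):
        self.model = model
        self.start, self.end = ranges
        self.warm_up = warm_up
        self.epochs = epochs
        self.restart = restart
        self.last_epoch = -1

    def step(self, epoch=None):
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        t = max(0, self.last_epoch - self.warm_up)
        if self.restart:
            t %= self.epochs                    # saw-tooth: restart the ramp
        frac = min(t / self.epochs, 1.0)
        self.model.tfr = self.start + (self.end - self.start) * frac

    def state_dict(self):
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, states):
        self.last_epoch = states["last_epoch"]

Under this reading, passing self.model.spell instead of self.model to the scheduler in train.py simply points it at the object that now owns tfr.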

asr/models/las/train.py

Lines changed: 12 additions & 8 deletions
@@ -20,23 +20,26 @@
 from ..trainer import *
 from .network import TFRScheduler, ListenAttendSpell

+#from .loss import EditDistanceLoss
+

 class LASTrainer(NonSplitTrainer):
     """Trainer for ListenAttendSpell model"""

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-        self.loss = nn.NLLLoss(reduction='mean', ignore_index=self.model.blk)
+        self.loss = nn.NLLLoss(reduction='none', ignore_index=self.model.blk)
+        #self.loss = EditDistanceLoss()

-        self.tfr_scheduler = TFRScheduler(self.model, ranges=(0.9, 0.0), warm_up=0, epochs=9, restart=True)
+        self.tfr_scheduler = TFRScheduler(self.model.spell, ranges=(0.9, 0.0), warm_up=0, epochs=9, restart=True)
         #self.tfr_scheduler.step(9)
         if self.states is not None and "tfr_scheduler" in self.states:
             self.tfr_scheduler.load_state_dict(self.states["tfr_scheduler"])

     def train_loop_before_hook(self):
         self.tfr_scheduler.step()
-        logger.debug(f"current tfr = {self.model.tfr:.3e}")
+        logger.debug(f"current tfr = {self.model.spell.tfr:.3e}")

     def train_loop_checkpoint_hook(self):
         self.plot_attention_heatmap()

@@ -68,8 +71,9 @@ def unit_train(self, data):
         ys_hat = ys_hat.float()
         if self.use_cuda:
             ys_lens = ys_lens.cuda()
-        loss = self.loss(ys_hat.transpose(1, 2), ys.long())
-        #loss = self.loss(ys_hat.transpose(1, 2), ys.long()).sum(dim=-1).div(ys_lens.float()).mean()
+        #loss = self.loss(ys_hat.transpose(1, 2), ys.long())
+        loss = self.loss(ys_hat.transpose(1, 2), ys.long()).sum(dim=-1).div(ys_lens.float()).mean()
+        #loss = self.loss(ys_hat, ys.long(), ys_hat_lens, ys_lens)
         #if ys_hat_lens is None:
         #    logger.debug("the batch includes a data with label_lens > max_seq_lens: ignore the entire batch")
         #    loss.mul_(0)

@@ -210,13 +214,13 @@ def batch_train(argv):
         #if i < 2:
         #    trainer.train_epoch(dataloaders["train3"])
         #    trainer.validate(dataloaders["dev"])
-        if i < 10:
+        if i < 30:
             trainer.train_epoch(dataloaders["warmup5"])
             trainer.validate(dataloaders["dev"])
-        elif i < 20:
+        elif i < 50:
             trainer.train_epoch(dataloaders["warmup10"])
             trainer.validate(dataloaders["dev"])
-        elif i < 30:
+        elif i < 60:
             trainer.train_epoch(dataloaders["train10"])
             trainer.validate(dataloaders["dev"])
         else:
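
The loss change is the substantive edit in this file: with reduction='none', the NLL is summed over time for each utterance, divided by that utterance's label length, and only then averaged over the batch, so long utterances no longer dominate the gradient the way they do under a flat reduction='mean' over all label positions. Since ignore_index=blk zeroes the padded positions, the per-utterance sum only counts real labels. A self-contained sketch of the same arithmetic on toy shapes (all names and values illustrative):

import torch
import torch.nn as nn

blk = 0                                     # hypothetical blank/pad index
criterion = nn.NLLLoss(reduction='none', ignore_index=blk)

B, H, T = 2, 5, 7                           # toy batch, label set, label length
ys_hat = torch.log_softmax(torch.randn(B, T, H), dim=-1)   # BxTxH log-probs
ys = torch.randint(1, H, (B, T))            # BxT targets
ys_lens = torch.LongTensor([7, 4])
ys[1, 4:] = blk                             # pad the shorter utterance

# per-position NLL -> sum over time -> length-normalize -> batch mean
nll = criterion(ys_hat.transpose(1, 2), ys)        # BxT
loss = nll.sum(dim=-1).div(ys_lens.float()).mean()
print(loss)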
