Description
I might be wrong, but I think the seq2seq model in the first notebook does not handle variable-length sequences properly (and the same issue probably carries over to the other notebooks). Specifically, the encoder runs the RNN to compute `hidden` and `cell`, which serve as the "context" summary of the input and are used to initialize the decoder's hidden and cell states. If T is the length of the longest sequence in a mini-batch, the encoder LSTM currently runs for T steps on every example, so each shorter sequence also steps over its padding tokens before `hidden` and `cell` are returned. Instead, the LSTM should run for T1 steps on example 1, T2 steps on example 2, and so on (where Ti is the length of the i-th sequence).
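To make the problem concrete, here is a toy sketch in plain Python. It assumes a hypothetical one-unit tanh RNN with made-up weights (not the notebook's LSTM) and a hypothetical padding value of 0.0; the only point is that every extra step over padding keeps updating the state, so the final state of a padded sequence is not the state after its last real token.

```python
import math


def rnn_final_state(tokens, w=0.5, u=1.0, b=0.1):
    """Toy scalar RNN (hypothetical weights): h_t = tanh(w*h_{t-1} + u*x_t + b)."""
    h = 0.0
    for x in tokens:
        h = math.tanh(w * h + u * x + b)
    return h


PAD = 0.0                   # hypothetical padding value
seq = [1.0, -2.0]           # a length-2 example (T_i = 2)
padded = seq + [PAD, PAD]   # padded to the batch max length T = 4

h_true = rnn_final_state(seq)       # state after the real tokens only
h_padded = rnn_final_state(padded)  # state after also stepping over padding
# Stopping after T_i steps (what packing does) recovers the true state
h_truncated = rnn_final_state(padded[:len(seq)])

print(h_true, h_padded, h_truncated)
```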
As a simple fix, I think you can use `pack_padded_sequence` in the encoder's `forward` method (see below); I believe it computes the hidden/cell states in the way I described. The data loader also has to provide a tensor with the length of each sequence in the batch (see below). A few other functions (e.g. the training and evaluation functions) and classes (the seq2seq class) need small changes as well so that they take `de_len` as an input. I've implemented this and it trains fine for me.
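```python
import torch
import torch.nn as nn


# Encoder.forward: pack the padded batch so the LSTM stops at each sequence's true length
def forward(self, src, lens):
    # src: [src_len, batch_size] (seq-first); lens: [batch_size] true lengths
    embedded = self.dropout(self.embedding(src))
    # lengths must be on the CPU; enforce_sorted=False lets the batch stay unsorted
    packed = nn.utils.rnn.pack_padded_sequence(embedded, lens.cpu(), batch_first=False, enforce_sorted=False)
    # hidden/cell now come from each sequence's last real token, not from padding
    packed_outputs, (hidden, cell) = self.rnn(packed)
    return hidden, cell


def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_en_ids = [example["en_ids"] for example in batch]
        batch_de_ids = [example["de_ids"] for example in batch]
        # true (unpadded) length of every sequence in the batch
        en_len = torch.tensor([example["en_ids"].shape[0] for example in batch])
        de_len = torch.tensor([example["de_ids"].shape[0] for example in batch])
        # pad to [max_len, batch_size]; pad_sequence is seq-first by default
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)
        batch_de_ids = nn.utils.rnn.pad_sequence(batch_de_ids, padding_value=pad_index)
        batch = {
            "en_ids": batch_en_ids,
            "de_ids": batch_de_ids,
            "en_lens": en_len,
            "de_lens": de_len,
        }
        return batch
    return collate_fn
```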
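A quick way to check the claim is to compare three runs of the same LSTM: on the padded batch, on the packed batch, and on each sequence alone with no padding. This is a standalone sketch with hypothetical toy sizes, token ids and padding index (none of them come from the notebook); it only assumes the seq-first layout used above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for this check
vocab_size, emb_dim, hid_dim, pad_index = 10, 4, 5, 0

embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_index)
rnn = nn.LSTM(emb_dim, hid_dim)  # seq-first, like the encoder above

# Two source sequences of different lengths (arbitrary token ids)
seqs = [torch.tensor([3, 4, 5, 6]), torch.tensor([7, 8])]
lens = torch.tensor([s.shape[0] for s in seqs])
src = nn.utils.rnn.pad_sequence(seqs, padding_value=pad_index)  # [4, 2]

with torch.no_grad():
    embedded = embedding(src)
    # Unpacked: every example runs for max_len steps, padding included
    _, (h_padded, _) = rnn(embedded)
    # Packed: each example stops at its own length; hidden is returned in the original batch order
    packed = nn.utils.rnn.pack_padded_sequence(embedded, lens, enforce_sorted=False)
    _, (h_packed, _) = rnn(packed)
    # Reference: run each sequence on its own (batch of 1, no padding)
    h_ref = torch.cat([rnn(embedding(s).unsqueeze(1))[1][0] for s in seqs], dim=1)

print("packed matches reference:", torch.allclose(h_packed, h_ref, atol=1e-6))
print("padded matches reference for the short sequence:",
      torch.allclose(h_padded[:, 1], h_ref[:, 1], atol=1e-6))
```

The longest sequence has no padding, so both runs agree on it; only the shorter one is affected, and packing removes the difference.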