
encoder not handling variable length sequences properly  #208

@dsrub


I might be wrong, but I think the seq2seq model in the first notebook does not handle variable-length sequences properly (this mistake probably carries over to the other notebooks as well). Specifically, in the encoder we use the RNN to compute hidden, cell as the summary "context" of the input, which initializes the hidden and cell states of the decoder. For a mini-batch, if T is the length of the longest sequence in the mini-batch, then we run the encoder LSTM for T steps on every example to compute hidden, cell. However, the LSTM should be run for T1 steps on example 1, T2 steps on example 2, etc. (where T1 is the length of the 1st sequence, and so on).
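To make the problem concrete, here is a small self-contained demonstration (toy dimensions, not the notebook's) that running the LSTM over padding changes the final hidden state for the shorter example, while packing stops at each example's true length:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
rnn = nn.LSTM(input_size=4, hidden_size=3)  # default seq-first layout

# batch of 2 sequences, true lengths 5 and 2, padded to T = 5
T, B = 5, 2
x = torch.randn(T, B, 4)
x[2:, 1] = 0.0  # zero padding for the shorter sequence
lens = torch.tensor([5, 2])

# naive: run all T steps for every example
_, (h_naive, _) = rnn(x)

# packed: stop at each example's true length
packed = pack_padded_sequence(x, lens, enforce_sorted=False)
_, (h_packed, _) = rnn(packed)

# the full-length sequence matches; the padded one does not, because the
# naive run keeps updating the state on the zero padding steps
print(torch.allclose(h_naive[0, 0], h_packed[0, 0]))  # True
print(torch.allclose(h_naive[0, 1], h_packed[0, 1]))  # False
```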

I think a simple fix is to use the pack_padded_sequence function in the forward method of the encoder (see below), which I believe computes the hidden/cell states in the fashion I described. The data loader will also have to provide a tensor of sequence lengths for each example in the batch (see below). Some of the other functions (e.g. the training and eval functions) and classes (the Seq2Seq class) will need to be slightly modified as well to accommodate taking in de_len as an input. I've implemented this and it trains fine for me.

def forward(self, src, lens):
    embedded = self.dropout(self.embedding(src))
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        embedded, lens.cpu().numpy(), enforce_sorted=False, batch_first=False
    )
    packed_outputs, (hidden, cell) = self.rnn(packed)
    return hidden, cell

import torch
import torch.nn as nn

def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_en_ids = [example["en_ids"] for example in batch]
        batch_de_ids = [example["de_ids"] for example in batch]
        en_len = torch.tensor([example["en_ids"].shape[0] for example in batch])
        de_len = torch.tensor([example["de_ids"].shape[0] for example in batch])
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)
        batch_de_ids = nn.utils.rnn.pad_sequence(batch_de_ids, padding_value=pad_index)
        batch = {
            "en_ids": batch_en_ids,
            "de_ids": batch_de_ids,
            "en_lens": en_len,
            "de_lens": de_len
        }
        return batch

    return collate_fn
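For the "slightly modified" parts, a hypothetical minimal sketch of how the lengths tensor threads through the model (class names, dimensions, and the src_lens parameter are illustrative, not the notebook's exact ones):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class Encoder(nn.Module):
    def __init__(self, vocab, emb, hid):
        super().__init__()
        self.embedding = nn.Embedding(vocab, emb)
        self.dropout = nn.Dropout(0.1)
        self.rnn = nn.LSTM(emb, hid)

    def forward(self, src, lens):
        embedded = self.dropout(self.embedding(src))
        # stop each example at its true length
        packed = pack_padded_sequence(embedded, lens.cpu(), enforce_sorted=False)
        _, (hidden, cell) = self.rnn(packed)
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, vocab, emb, hid):
        super().__init__()
        self.embedding = nn.Embedding(vocab, emb)
        self.rnn = nn.LSTM(emb, hid)
        self.fc = nn.Linear(hid, vocab)

    def forward(self, token, hidden, cell):
        emb = self.embedding(token.unsqueeze(0))         # (1, B, emb)
        out, (hidden, cell) = self.rnn(emb, (hidden, cell))
        return self.fc(out.squeeze(0)), hidden, cell     # (B, vocab)

class Seq2Seq(nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.encoder, self.decoder = enc, dec

    def forward(self, src, src_lens, trg):
        hidden, cell = self.encoder(src, src_lens)       # lens passed through
        outputs = []
        token = trg[0]                                   # <sos> row
        for t in range(1, trg.shape[0]):
            logits, hidden, cell = self.decoder(token, hidden, cell)
            outputs.append(logits)
            token = trg[t]                               # teacher forcing
        return torch.stack(outputs)

# smoke test: padded batch of 2 source sequences (true lengths 4 and 2)
src = torch.randint(1, 10, (4, 2)); src[2:, 1] = 0
src_lens = torch.tensor([4, 2])
trg = torch.randint(1, 10, (5, 2))
model = Seq2Seq(Encoder(10, 8, 16), Decoder(10, 8, 16))
out = model(src, src_lens, trg)
print(out.shape)  # torch.Size([4, 2, 10])
```

The training and eval functions then just unpack the extra lengths tensor from the batch dict and forward it along with the source.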
