encoder not handling variable length sequences properly #208

Open
dsrub opened this issue Sep 23, 2024 · 0 comments
dsrub commented Sep 23, 2024

I might be wrong, but I think the seq2seq model in the first notebook does not handle variable-length sequences properly (and the mistake probably carries over to the other notebooks as well). Specifically, in the encoder we use the RNN to compute hidden, cell as the summary "context" of the input, which then initializes the hidden and cell states of the decoder. For a mini-batch, if T is the length of the longest sequence in the batch, the encoder's LSTM is currently run for T steps on every example. However, it should be run for T1 steps on example 1, T2 steps on example 2, and so on (where T1 is the length of the first sequence, etc.), so that padding tokens do not pollute the final hidden/cell states.
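A toy check of this point (hypothetical sizes, not the notebook's model): with pack_padded_sequence, the final hidden state returned for a padded example matches what you get by running that example alone at its true length, i.e. padding steps are skipped.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.LSTM(input_size=4, hidden_size=3)  # seq-first (batch_first=False)

long_seq = torch.randn(5, 4)   # length 5
short_seq = torch.randn(2, 4)  # length 2, gets padded up to 5 in the batch

batch = nn.utils.rnn.pad_sequence([long_seq, short_seq])  # shape [5, 2, 4]
lens = torch.tensor([5, 2])

# Packed run: the LSTM stops updating each example at its true length.
packed = nn.utils.rnn.pack_padded_sequence(batch, lens, enforce_sorted=False)
_, (hidden, _) = rnn(packed)

# Reference: run the short sequence alone, with no padding at all.
_, (hidden_ref, _) = rnn(short_seq.unsqueeze(1))

print(torch.allclose(hidden[:, 1], hidden_ref[:, 0]))  # True
```

Without packing, hidden[:, 1] would instead be the state after processing three extra padding steps, which is what the current notebooks feed to the decoder.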

As a simple fix, I think you can use the pack_padded_sequence function in the forward method of the encoder (see below), which I believe computes the hidden/cell states in the fashion I described. The data loader will also have to provide a tensor of sequence lengths for each example in the batch (see below). Some of the other functions (e.g. the training and eval functions) and classes (the Seq2Seq class) will need to be slightly modified as well to accommodate taking de_len as an input. I've implemented this and it trains fine for me.

def forward(self, src, lens):
    embedded = self.dropout(self.embedding(src))
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        embedded, lens.cpu().numpy(), enforce_sorted=False, batch_first=False
    )
    packed_outputs, (hidden, cell) = self.rnn(packed)
    return hidden, cell
def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_en_ids = [example["en_ids"] for example in batch]
        batch_de_ids = [example["de_ids"] for example in batch]
        en_len = torch.tensor([example["en_ids"].shape[0] for example in batch])
        de_len = torch.tensor([example["de_ids"].shape[0] for example in batch])
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)
        batch_de_ids = nn.utils.rnn.pad_sequence(batch_de_ids, padding_value=pad_index)
        batch = {
            "en_ids": batch_en_ids,
            "de_ids": batch_de_ids,
            "en_lens": en_len,
            "de_lens": de_len
        }
        return batch

    return collate_fn
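For reference, here is an end-to-end toy run of the two pieces together (made-up vocabulary and hidden sizes, a zero pad index, and dropout disabled are my assumptions, not the notebook's settings): pad a small batch the way the collate_fn does, then feed it through an encoder whose forward packs before the LSTM.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Toy sizes: vocab 10, embedding dim 4, hidden dim 3.
        self.embedding = nn.Embedding(10, 4, padding_idx=0)
        self.rnn = nn.LSTM(4, 3)
        self.dropout = nn.Dropout(0.0)

    def forward(self, src, lens):
        embedded = self.dropout(self.embedding(src))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lens.cpu(), enforce_sorted=False
        )
        _, (hidden, cell) = self.rnn(packed)
        return hidden, cell

# Two sequences of different lengths, padded like the collate_fn above.
ids = [torch.tensor([5, 6, 7]), torch.tensor([3])]
src = nn.utils.rnn.pad_sequence(ids, padding_value=0)  # shape [3, 2]
lens = torch.tensor([3, 1])

hidden, cell = Encoder()(src, lens)
print(hidden.shape, cell.shape)  # torch.Size([1, 2, 3]) for both
```

The Seq2Seq class and train/eval loops then only need to pass the extra length tensor through to the encoder call.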