Question about implementation in the multi-head attention part #15
It may look the same, but unfortunately it's not.
Here is some sample code to help you understand the difference easily.
import torch
x = torch.randint(0,3,(1,4,6)) # [batch_size, seq_len, d_model]
print(x)
"""
tensor([[[0, 2, 2, 1, 2, 2],
[2, 0, 0, 2, 0, 1],
[2, 1, 2, 0, 1, 0],
[0, 2, 1, 0, 1, 0]]])
"""
print(x.view(1, 3, 4, 2)) # [batch_size, n_heads, seq_len, (d_model//n_heads)] ==> as you can see below, it is wrong.
"""
tensor([[[[0, 2],
[2, 1],
[2, 2],
[2, 0]],
[[0, 2],
[0, 1],
[2, 1],
[2, 0]],
[[1, 0],
[0, 2],
[1, 0],
[1, 0]]]])
"""
print(x.view(1, 4, 3, 2).transpose(1, 2)) # [batch_size, n_heads, seq_len, (d_model//n_heads)] ==> correctly done.
"""
tensor([[[[0, 2],
[2, 0],
[2, 1],
[0, 2]],
[[2, 1],
[0, 2],
[2, 0],
[1, 0]],
[[2, 2],
[0, 1],
[1, 0],
[1, 0]]]])
""" |
Hi everyone. As we all know, a matrix is stored in memory as a one-dimensional array even if it has higher dimensionality; e.g. a 3x4 matrix has 2 dimensions, but in memory it is simply a flat array of 12 elements. Let us use @goattier's example to go through the transpose process.
In memory, the example tensor has this structure:
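For example, hard-coding the values printed in the comment above (so the snippet is reproducible), the flat storage order looks like this:

import torch

x = torch.tensor([[[0, 2, 2, 1, 2, 2],
                   [2, 0, 0, 2, 0, 1],
                   [2, 1, 2, 0, 1, 0],
                   [0, 2, 1, 0, 1, 0]]])  # [batch_size, seq_len, d_model] = [1, 4, 6]
print(x.flatten())
# tensor([0, 2, 2, 1, 2, 2, 2, 0, 0, 2, 0, 1, 2, 1, 2, 0, 1, 0, 0, 2, 1, 0, 1, 0])
# 24 elements laid out contiguously in row-major order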
So how can PyTorch (or NumPy) use an index to find the correct element of this matrix? The answer is the 'stride'. In this case we first view the tensor as (1, 4, 3, 2), so the result has 4 dimensions, i.e. 4 axes; let's calculate the strides starting from the last axis: the last axis has stride 1, and each earlier axis's stride is the product of the sizes of all axes after it.
Then, transposing just exchanges the strides (and sizes) of the axes; in this case it swaps axis 1 and axis 2.
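Continuing with the x defined above (y and z are just illustrative names), here is a small sketch of the strides:

y = x.view(1, 4, 3, 2)                # [batch_size, seq_len, n_heads, d_model // n_heads]
print(y.stride())                     # (24, 6, 2, 1): last axis 1, then 2, then 3*2 = 6, then 4*3*2 = 24
z = y.transpose(1, 2)                 # [batch_size, n_heads, seq_len, d_model // n_heads]
print(z.stride())                     # (24, 2, 6, 1): the strides of axis 1 and axis 2 are swapped
print(y.data_ptr() == z.data_ptr())   # True: same underlying storage, only the stride metadata changed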
Now let's see what happens if we view the tensor as (1, 3, 4, 2) directly; the details are shown below.
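Continuing from the snippet above (w is an illustrative name):

w = x.view(1, 3, 4, 2)    # reshape directly, without transposing
print(w.stride())         # (24, 8, 2, 1): still contiguous, so each "head" simply takes the next
                          # 8 consecutive elements of the flat storage, instead of gathering each
                          # position's 2-element slice for that head
print(torch.equal(w, z))  # False: the same 24 values, grouped differently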
And it is absolutely not the same as the first one. Hope this comment helps you :)
@AlanYeg, yes, let me fix it.
Hi, first I want to thank you for sharing the repo; it has been very helpful for understanding the transformer through your code.
I just have one question about your multi-head attention part.
In the forward function, you have
out = x.view(batch_size, seq_len, self.num_heads, d).transpose(1, 2)
I understand the desired output shape should be [batch_size, num_heads, seq_len, d]. But we can do
out = x.view(batch_size, self.num_heads, seq_len, d)
without using the transpose function.
Is there any particular reason we need to reshape and then transpose it?
Thanks