
Question about implementation in the multi-head attention part #15

Acatsama0871 opened this issue Mar 28, 2023 · 4 comments

@Acatsama0871

Hi, first I want to thank you for sharing this repo; it has been very helpful for understanding the transformer through your code.

I just have one question about your multi-head attention part.

In the forward function, you have
out = x.view(batch_size, seq_len, self.num_heads, d).transpose(1, 2)

I understand the desired output shape should be [batch_size, num_heads, seq_len, d]. But we can do
out = x.view(batch_size, self.num_heads, seq_len, d)
without using the transpose function.

Is there any particular reason we need to reshape and then transpose it?

Thanks


goattier commented Apr 3, 2023

It may look the same, but unfortunately it's not.
i.e. x.view(batch_size, seq_len, self.num_heads, d).transpose(1, 2) is not equal to x.view(batch_size, self.num_heads, seq_len, d).

x.view(batch_size, self.num_heads, seq_len, d) scrambles the embedding vectors of the (word) tokens. So if you want to split the Q, K, and V matrices into multiple heads without corrupting the attention operation, you should use the first form.

Here is some sample code to make this easy to see.

import torch

x = torch.randint(0,3,(1,4,6)) # [batch_size, seq_len, d_model]
print(x)
"""
tensor([[[0, 2, 2, 1, 2, 2],
         [2, 0, 0, 2, 0, 1],
         [2, 1, 2, 0, 1, 0],
         [0, 2, 1, 0, 1, 0]]])
"""


print(x.view(1, 3, 4, 2)) # [batch_size, n_heads, seq_len, (d_model//n_heads)] ==> as you can see below, it is wrong.
"""
tensor([[[[0, 2],
          [2, 1],
          [2, 2],
          [2, 0]],

         [[0, 2],
          [0, 1],
          [2, 1],
          [2, 0]],

         [[1, 0],
          [0, 2],
          [1, 0],
          [1, 0]]]])
"""


print(x.view(1, 4, 3, 2).transpose(1, 2)) # [batch_size, n_heads, seq_len, (d_model//n_heads)] ==> correctly done.
"""
tensor([[[[0, 2],
          [2, 0],
          [2, 1],
          [0, 2]],

         [[2, 1],
          [0, 2],
          [2, 0],
          [1, 0]],

         [[2, 2],
          [0, 1],
          [1, 0],
          [1, 0]]]])
"""


Trevorcat commented Apr 3, 2023

Hi everyone,
Thanks to @goattier for the really nice example; it helped me figure out why this happens. I think I can add some more detail on this issue.

As we all know, a matrix is always stored in memory as a one-dimensional array, no matter how many dimensions it has. E.g., a 3x4 matrix has 2 dimensions, but in memory it is a flat array of 12 elements.

Let's use @goattier's example to walk through the transpose process.
The matrix is

[[[0, 2, 2, 1, 2, 2],
  [2, 0, 0, 2, 0, 1],
  [2, 1, 2, 0, 1, 0],
  [0, 2, 1, 0, 1, 0]]]

In memory it is laid out like this:

[0, 2, 2, 1, 2, 2,    2, 0, 0, 2, 0, 1,     2, 1, 2, 0, 1, 0,     0, 2, 1, 0, 1, 0]
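
A quick way to confirm this is to flatten the tensor (using the same values as @goattier's example):

import torch

x = torch.tensor([[[0, 2, 2, 1, 2, 2],
                   [2, 0, 0, 2, 0, 1],
                   [2, 1, 2, 0, 1, 0],
                   [0, 2, 1, 0, 1, 0]]])

print(x.flatten())
"""
tensor([0, 2, 2, 1, 2, 2, 2, 0, 0, 2, 0, 1, 2, 1, 2, 0, 1, 0, 0, 2, 1, 0, 1, 0])
"""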

So how does PyTorch (or NumPy) use an index to find the correct element of this matrix? The answer is 'strides'!

In this case, we want to view the matrix as (1, 4, 3, 2) and then transpose it. The result has 4 dimensions, i.e. 4 axes, so let's compute the stride of each axis, starting from the last one.
Along axis 3, moving to the next element takes a single step in memory, so the stride of this axis is 1.
Along axis 2, moving to the next element means skipping one group of 2 elements, so the stride of this axis is 2.
By the same logic, the stride of axis 1 is 2x3=6.
The stride of axis 0 is 2x3x4=24.

axis 0   => 24
axis 1   => 6
axis 2   => 2
axis 3   => 1

Then, transposing simply means swapping the sizes and strides of the axes involved, in this case axis 1 and axis 2; no data is moved.

axis 0   => 24
axis 1   => 2
axis 2   => 6
axis 3   => 1

You can see the details below:

[figure: indexing the transposed tensor with strides (24, 2, 6, 1)]
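
You can check these numbers directly with PyTorch's Tensor.stride() (continuing with x from the snippet above):

y = x.view(1, 4, 3, 2)
print(y.stride())  # (24, 6, 2, 1)

z = y.transpose(1, 2)
print(z.stride())  # (24, 2, 6, 1): the strides of axes 1 and 2 are swapped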

Now let's see what happens if we view the matrix as (1, 3, 4, 2) directly.
Along axis 3, moving to the next element takes a single step in memory, so the stride of this axis is 1.
Along axis 2, moving to the next element means skipping one group of 2 elements, so the stride of this axis is 2.
By the same logic, the stride of axis 1 is 2x4=8.
The stride of axis 0 is 2x4x3=24.

axis 0   => 24
axis 1   => 8
axis 2   => 2
axis 3   => 1

[figure: indexing the direct view with strides (24, 8, 2, 1)]
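
Again, stride() confirms it:

w = x.view(1, 3, 4, 2)
print(w.stride())  # (24, 8, 2, 1)

# w has the same shape as the transposed version, but different strides,
# so the same index reads a different element from memory.
z = x.view(1, 4, 3, 2).transpose(1, 2)
print(w.shape == z.shape, w.stride() == z.stride())  # True False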

And that is definitely not the same as the first one. Hope this comment helps :)

@Trevorcat

@AlanYeg, yes, let me fix it
