After the attention layers were split across all nodes, I missed the implications this introduced.
Long story short: to calculate the attention for a single Q head, I need the whole corresponding K head. For Q head x, I need the whole K head floor(x / (nHeads / nKvHeads)) to calculate the result.
For example Llama 3 8B:
💡 dim: 4096 (headSize: 128)
💡 nHeads: 32
💡 nKvHeads: 8
Q head 0 => floor( 0 / ( 32 / 8) ) => K head 0
Q head 1 => floor( 1 / ( 32 / 8) ) => K head 0
Q head 2 => floor( 2 / ( 32 / 8) ) => K head 0
...
Q head 8 => floor( 8 / ( 32 / 8) ) => K head 2
Q head 9 => floor( 9 / ( 32 / 8) ) => K head 2
...
Q head 31 => floor( 31 / ( 32 / 8) ) => K head 7
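A minimal, self-contained C++ sketch of this mapping (just the arithmetic above, not code taken from the project):

```cpp
#include <cstdio>

int main() {
    // Llama 3 8B attention geometry (values from the example above)
    const int nHeads = 32;
    const int nKvHeads = 8;
    const int qHeadsPerKvHead = nHeads / nKvHeads; // 4 Q heads share one K/V head

    for (int qHead = 0; qHead < nHeads; qHead++) {
        int kvHead = qHead / qHeadsPerKvHead; // integer division == floor()
        printf("Q head %2d => K/V head %d\n", qHead, kvHead);
    }
    return 0;
}
```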
Because of this, it is currently not possible to split the model across more than nKvHeads nodes.
^ The same problem applies to the V layer.
How could this be fixed?
1. Synchronize missing outputs
For nSlices > nKvHeads setups, a new synchronization step could be introduced. This step would synchronize the missing K/V outputs across nodes. Of course, synchronization is the slowest part of Distributed Llama.
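To make the gap concrete, here is a hedged sketch of which K/V heads a node needs versus owns under an even split. The layout assumed here (Q heads and K/V rows divided evenly per node) is only an illustration, not necessarily how Distributed Llama actually shards the weights:

```cpp
#include <cstdio>

int main() {
    const int nHeads = 32, nKvHeads = 8, headSize = 128;
    const int nSlices = 16;                      // more slices than K/V heads
    const int qHeadsPerNode = nHeads / nSlices;  // 2
    const int kvDim = nKvHeads * headSize;       // 1024
    const int kvDim0 = kvDim / nSlices;          // 64 rows of K (and V) per node

    for (int node = 0; node < nSlices; node++) {
        for (int q = 0; q < qHeadsPerNode; q++) {
            int qHead = node * qHeadsPerNode + q;
            int kvHead = qHead / (nHeads / nKvHeads);
            // rows of K assumed local to this node: [node * kvDim0, (node + 1) * kvDim0)
            int neededStart = kvHead * headSize;
            int ownedStart = node * kvDim0;
            bool fullyLocal = neededStart >= ownedStart &&
                              neededStart + headSize <= ownedStart + kvDim0;
            printf("node %2d, Q head %2d needs K/V head %d (%s)\n",
                   node, qHead, kvHead, fullyLocal ? "local" : "needs sync");
        }
    }
    return 0;
}
```

With kvDim0 = 64 and headSize = 128, no node owns a whole K/V head, so every Q head would require a sync under this layout.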
2. Redundancy
Redundancy could be introduced for the K/V layers. These layers should be split with alignment to headSize. With this approach there is no synchronization, and the amount of redundant calculation seems small (headSize - kvDim0).
For example Llama 3 8B:
headSize = dim / nHeads = 128
kvDim = (dim * nKvHeads) / nHeads = 1024
nSlices = 16
kvDim0 = kvDim / nSlices = 64
redundancy = 128 - 64 = 64 outputs of K & V
nSlices = 32
kvDim0 = kvDim / nSlices = 32
redundancy = 128 - 32 = 96 outputs of K & V
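A tiny standalone sketch of this redundancy estimate (only the alignment arithmetic above, not project code):

```cpp
#include <cstdio>

int main() {
    const int dim = 4096, nHeads = 32, nKvHeads = 8;
    const int headSize = dim / nHeads;             // 128
    const int kvDim = (dim * nKvHeads) / nHeads;   // 1024

    // If every node computes a full K/V head (headSize rows) instead of only
    // its kvDim0 slice, the extra work per node is headSize - kvDim0 outputs
    // for K and for V.
    const int sliceCounts[] = { 16, 32 };
    for (int nSlices : sliceCounts) {
        int kvDim0 = kvDim / nSlices;
        int redundancy = headSize - kvDim0;
        printf("nSlices=%2d: kvDim0=%d, redundancy=%d outputs of K & V per node\n",
               nSlices, kvDim0, redundancy);
    }
    return 0;
}
```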