This groups the scaled dot-product attention logic, including all the masking, under a single function. That should be more manageable and make it easier to replace the implementation in the future.
I extracted this from an attempt at flash attention, which itself didn't turn out to be worthwhile.
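For reference, a minimal sketch of roughly the shape such a grouped function takes (illustrative only; the function name, signature, and mask convention here are assumptions, not the exact code in this PR):

```python
import jax
import jax.numpy as jnp


def scaled_dot_product_attention(q, k, v, mask=None):
    """q, k, v: [..., seq, head_dim]; mask (if given) broadcasts to [..., q_len, kv_len], True = attend."""
    d = q.shape[-1]
    scores = jnp.einsum("...qd,...kd->...qk", q, k) / jnp.sqrt(d)  # [..., q_len, kv_len]
    if mask is not None:
        # Masked-out positions get a large negative score before the softmax.
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)
```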
Flash attention
Attention is defined as $softmax(\frac{QK^T}{\sqrt{d}})V$. The direct implementation requires materializing $QK^T$ in memory to compute the row-wise softmax. This requires a lot of memory and, on the GPU, a lot of transfer between fast on-chip SRAM and off-chip HBM (which is the actual bottleneck time-wise). Flash attention is a technique where we chunk $Q$, $K$, $V$ into blocks and iteratively compute smaller dot products, doing some rescaling along the way. An effective implementation requires control over memory access and over what part of the work each GPU thread does. It has been shown to improve computation speed precisely because it reduces the slow memory IO.
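The rescaling that makes the chunking work is the standard online-softmax identity: the softmax over the concatenation of two score blocks $s^{(1)}$ and $s^{(2)}$ can be recovered from per-block statistics alone. With $m_i = \max_j s^{(i)}_j$ and $\ell_i = \sum_j e^{s^{(i)}_j - m_i}$:

$$
m = \max(m_1, m_2), \qquad \ell = e^{m_1 - m}\,\ell_1 + e^{m_2 - m}\,\ell_2, \qquad softmax([s^{(1)}; s^{(2)}])_j = \frac{e^{s_j - m}}{\ell}.
$$

The same correction factors $e^{m_i - m}$ are applied to the partially accumulated outputs, which is what lets the $K$, $V$ blocks be processed one at a time.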
I implemented flash attention using while loops, similar to existing Jax attempts (paper, more code). But from my tests on the GPU with actual models it doesn't translate to meaningful improvements. Time-wise it is slower. It does allow longer sequence inputs (roughly 2x), but given that it's slower, I'd say it makes more sense to just use a smaller batch size and regular attention. This is not exactly surprising, because we are at XLA's mercy as to which GPU kernels are generated, and as far as I understand it won't convert a `while` loop into a CUDA `for` loop. Also, this may be relevant: given that the Jax implementation hasn't been adopted in hf/transformers for Flax yet, it's probably not good enough. For PyTorch, both hf/transformers and xFormers use a flash attention implementation that dispatches to specialized CUDA kernels. As for Jax, there is an ongoing/experimental effort for a semi-high-level DSL for writing custom kernels to overcome XLA limitations - Pallas. It is similar to Triton, but meant to be a bit more generic, so that it can be lowered to either Triton (for GPU) or Mosaic (for TPU); it is also more high-level, since it allows using most of the usual Jax operations, which are lowered automatically. Jax already has an experimental flash attention implementation using Pallas, but it's still a work in progress (ref) and it seems like the direction things are going.
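For the record, the loop-based version I tried was roughly of this shape (a simplified single-head sketch assuming no masking and `kv_len` divisible by the chunk size; the `chunked_attention` name is illustrative, not the exact code from the experiment):

```python
import jax
import jax.numpy as jnp
from jax import lax


def chunked_attention(q, k, v, chunk_size=128):
    """Single-head sketch: q [q_len, d], k/v [kv_len, d]."""
    q_len, d = q.shape
    num_chunks = k.shape[0] // chunk_size
    scale = 1.0 / jnp.sqrt(d)

    def body(i, carry):
        out, row_max, row_sum = carry
        k_chunk = lax.dynamic_slice_in_dim(k, i * chunk_size, chunk_size, axis=0)
        v_chunk = lax.dynamic_slice_in_dim(v, i * chunk_size, chunk_size, axis=0)
        scores = (q @ k_chunk.T) * scale                   # [q_len, chunk_size]
        chunk_max = scores.max(axis=-1, keepdims=True)     # [q_len, 1]
        new_max = jnp.maximum(row_max, chunk_max)
        # Rescale previously accumulated statistics to the new running max.
        correction = jnp.exp(row_max - new_max)
        p = jnp.exp(scores - new_max)                      # unnormalized probabilities for this chunk
        new_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        new_out = out * correction + p @ v_chunk
        return new_out, new_max, new_sum

    init = (
        jnp.zeros((q_len, d), q.dtype),
        jnp.full((q_len, 1), -jnp.inf, q.dtype),
        jnp.zeros((q_len, 1), q.dtype),
    )
    out, _, row_sum = lax.fori_loop(0, num_chunks, body, init)
    return out / row_sum  # final normalization, deferred to the end
```

Numerically this matches regular attention up to floating-point error, and it never materializes the full `[q_len, kv_len]` score matrix; whether XLA turns the `fori_loop` into anything efficient on the GPU is exactly the problem described above.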
Notes from relevant papers
Flash-Attention (https://arxiv.org/pdf/2205.14135.pdf) - GPU memory consists of fast on-chip SRAM and slower off-chip HBM. HBM access is the bottleneck of the basic attention implementation in the wall-clock sense. The paper presents a GPU-level algorithm for reducing HBM access. The algorithm splits the Q, K, V tensors into chunks and does a nested loop (in each thread): for every K, V chunk it does a pass over the Q chunks and updates the corresponding output chunks. The main difficulty is the softmax, which for a single query depends on all of K; this is handled by computing the softmax for a chunk and then iteratively rescaling and updating it as new K chunks are processed. Note that the algorithm is rather low-level, since the proposed implementation uses CUDA with direct memory control; it is also parameterized by the SRAM size M.
Flash-Attention-2 (https://arxiv.org/pdf/2307.08691.pdf) - improves on top of Flash-Attention. The modified algorithm reduces the number of non-matmul flops by moving one output-scaling operation from the loop body to a single scaling at the end. It also swaps the nested loops, so that the outer loop (over query chunks) is embarrassingly parallel; it can therefore be used to parallelize the work in addition to the batch size and attention heads (both of which are usually small in the cases where long sequence lengths are used). Again, this paper discusses the implementation at the CUDA level.
Note: swapping the loop order is essentially a prerequisite for the deferred scaling. With that change, the final scaling can be done right after the inner loop. Otherwise, it would have to happen after both loops, which means either keeping the whole output in SRAM or doing another pass over all chunks - both of which would be very counterproductive.
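Schematically, the Flash-Attention-2 nesting looks as follows (a sketch under the same simplifying assumptions as the code above, with the parallel outer loop expressed via `vmap` and the online-softmax step inlined; names and chunk sizes are illustrative):

```python
import jax
import jax.numpy as jnp


def fa2_style_attention(q, k, v, q_chunk=128, kv_chunk=128):
    """q: [q_len, d], k/v: [kv_len, d]; lengths assumed divisible by the chunk sizes."""
    q_len, d = q.shape
    kv_len = k.shape[0]
    k_blocks = k.reshape(kv_len // kv_chunk, kv_chunk, d)
    v_blocks = v.reshape(kv_len // kv_chunk, kv_chunk, d)

    def per_query_chunk(qb):  # outer "loop": independent per query chunk -> parallel via vmap
        def body(carry, kv):  # inner loop: online-softmax accumulation over K, V chunks
            out, m, l = carry
            kb, vb = kv
            s = qb @ kb.T / jnp.sqrt(d)
            new_m = jnp.maximum(m, s.max(axis=-1, keepdims=True))
            corr = jnp.exp(m - new_m)
            p = jnp.exp(s - new_m)
            return (out * corr + p @ vb, new_m, l * corr + p.sum(axis=-1, keepdims=True)), None

        init = (jnp.zeros_like(qb),
                jnp.full((q_chunk, 1), -jnp.inf, q.dtype),
                jnp.zeros((q_chunk, 1), q.dtype))
        (out, _, l), _ = jax.lax.scan(body, init, (k_blocks, v_blocks))
        return out / l  # the single deferred scaling, right after the inner loop

    q_blocks = q.reshape(q_len // q_chunk, q_chunk, d)
    return jax.vmap(per_query_chunk)(q_blocks).reshape(q_len, d)
```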
JAX memory-efficient attention (https://arxiv.org/pdf/2112.05682.pdf) - this paper discusses similar ideas to the two papers above, but at a higher level. It provides a JAX implementation of memory-efficient attention.
Note: the paper focuses on TPU and doesn't show speed benefits, only memory improvements.
Flash-Decoding (https://pytorch.org/blog/flash-decoding) - this approach builds on top of Flash-Attention-2 but focuses on inference, specifically autoregressive decoding. It points out that in autoregressive decoding each subsequent pass works with a single query, so the main bottleneck is reading the K, V chunks from memory (in a loop) while GPU utilization stays very low. The improved algorithm first splits K, V into a few groups; within each group it splits K, V further into chunks and iteratively computes the result exactly as Flash-Attention-2 does. The groups, however, are computed in parallel on the GPU, and once they finish, the partial results are combined. This increases GPU utilization.
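The final combination is the same rescaling trick again; a sketch of how the per-group partial results might be merged (assuming each group returns its unnormalized output together with its running max and sum; `combine_partials` is an illustrative name, not Flash-Decoding's actual API):

```python
import jax.numpy as jnp


def combine_partials(outs, maxes, sums):
    """Merge per-group partial attention results.

    outs:  [groups, q_len, d]  unnormalized partial outputs (sum of p @ V over the group's chunks)
    maxes: [groups, q_len, 1]  per-group running max of the scores
    sums:  [groups, q_len, 1]  per-group running sum of exp(scores - max)
    """
    global_max = maxes.max(axis=0)               # [q_len, 1]
    correction = jnp.exp(maxes - global_max)     # rescale each group to the global max
    total_sum = (sums * correction).sum(axis=0)  # [q_len, 1]
    combined = (outs * correction).sum(axis=0)   # [q_len, d]
    return combined / total_sum
```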