Refactor attention implementation #300

Merged
jonatanklosko merged 2 commits into main from jk-attention on Dec 12, 2023
Conversation

@jonatanklosko (Member) commented Dec 12, 2023

This groups the scaled dot-product attention logic, including all of the masking, under a single function. This should be more manageable and make it easier to replace the implementation in the future.

I extracted this from an attempt at flash attention, which itself didn't turn out to be worthwhile.

Flash attention

Attention is defined as $\mathrm{softmax}(\frac{QK^T}{\sqrt{d}})V$. The direct implementation requires materializing $QK^T$ in memory to compute the row-wise softmax. This requires a lot of memory and, on the GPU, a lot of transfer between fast on-chip SRAM and off-chip HBM (which is actually the time-wise bottleneck). Flash attention is a technique where we chunk all of $Q, K, V$ into blocks and iteratively compute smaller dot products; we just need to do some rescaling along the way. An efficient implementation requires control over memory access and over what part of the work each GPU thread does. It has been shown to improve computation speed, precisely because it reduces the slow memory IO.
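
For reference, here is a minimal JAX sketch of the direct implementation (single head, no masking; illustrative only, not code from this PR) - the full $QK^T$ score matrix of shape (query length × key length) is materialized before the softmax:

```python
import jax
import jax.numpy as jnp

def naive_attention(q, k, v):
    # q: [seq_q, d], k: [seq_k, d], v: [seq_k, d_v]
    d = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d)             # materializes a [seq_q, seq_k] matrix
    weights = jax.nn.softmax(scores, axis=-1)  # row-wise softmax over all keys
    return weights @ v
```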

I implemented flash attention using while loops, similar to the Jax attempts (paper, more code); a rough sketch of the while-loop approach is included after the quote below. But from my tests on the GPU with actual models it doesn't really translate to meaningful improvements. Time-wise it is slower. It does allow longer sequence inputs (like 2x), but given that it's slower, I'd say it makes more sense to just use a smaller batch size and regular attention. This is not exactly surprising, because we are at XLA's mercy as to what GPU kernels are generated, and as far as I understand it wouldn't convert a while loop into a CUDA for loop. Also, this may be relevant:

For sequence lengths <= 4096, XLA has some fusions that do something similar to flash attention so the expected improvement over XLA isn't that big. Once you go to 8k and above, you should see much bigger improvements. ~ jax-ml/jax#18590 (comment)
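
For illustration, this is roughly what the while-loop-based chunked attention looks like, as a hypothetical JAX sketch (single head, no masking, and not the code from this PR). The loop carry holds the running row-wise maximum m, the softmax normalizer l, and the unnormalized output o, all of which get rescaled as each new K/V chunk is processed:

```python
import jax
import jax.numpy as jnp

def chunked_attention(q, k, v, chunk_size=128):
    # Online-softmax attention over K/V chunks; assumes the K/V sequence
    # length is divisible by chunk_size for simplicity.
    seq_q, d = q.shape
    num_chunks = k.shape[0] // chunk_size

    def body(i, carry):
        o, l, m = carry
        k_chunk = jax.lax.dynamic_slice_in_dim(k, i * chunk_size, chunk_size)
        v_chunk = jax.lax.dynamic_slice_in_dim(v, i * chunk_size, chunk_size)
        s = q @ k_chunk.T / jnp.sqrt(d)          # [seq_q, chunk_size] scores
        m_new = jnp.maximum(m, s.max(axis=-1))   # updated running max
        p = jnp.exp(s - m_new[:, None])          # chunk-local exponentials
        scale = jnp.exp(m - m_new)               # rescale the previous state
        l = l * scale + p.sum(axis=-1)
        o = o * scale[:, None] + p @ v_chunk
        return o, l, m_new

    o0 = jnp.zeros((seq_q, v.shape[-1]))
    l0 = jnp.zeros((seq_q,))
    m0 = jnp.full((seq_q,), -jnp.inf)
    o, l, _ = jax.lax.fori_loop(0, num_chunks, body, (o0, l0, m0))
    return o / l[:, None]
```

At no point is the full [seq_q, seq_k] score matrix in memory, only one [seq_q, chunk_size] block at a time - but whether this turns into an efficient kernel is entirely up to XLA.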

Given that the Jax implementation hasn't been adopted in hf/transformers for Flax yet, it suggests it's not good enough. For PyTorch, both hf/transformers and xFormers use a flash attention implementation that dispatches to specialized CUDA kernels. As for Jax, there is an ongoing/experimental effort towards a semi-high-level DSL for writing custom kernels to overcome XLA limitations - Pallas. It is similar to Triton, but meant to be a bit more generic, such that it is lowered to either Triton (for GPU) or Mosaic (for TPU); it is also more high-level, since it allows using most of the usual Jax operations, which are lowered automatically. Jax already has an experimental flash attention implementation using Pallas, but it's still a work in progress (ref), and it seems to be the direction things are going.
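
For a taste of the programming model (a minimal, hypothetical example along the lines of the Pallas quickstart, not related to attention): the kernel body is written with ordinary Jax operations over block references, and pallas_call lowers it to Triton on GPU or Mosaic on TPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each kernel instance reads its input blocks and writes the elementwise sum.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)
```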

Notes from relevant papers
  1. Flash-Attention (https://arxiv.org/pdf/2205.14135.pdf) - GPU memory consists of fast on-chip SRAM and slower off-chip HBM. HBM access is the bottleneck of the basic attention implementation in the wall-clock sense. The paper presents a GPU-level algorithm for reducing HBM access. The algorithm splits all of the Q, K, V tensors into chunks and does a nested loop (in each thread): for every K, V chunk it does a pass over the Q chunks and updates the corresponding output chunks. The main challenge is the softmax, which for a single query depends on all of K; this can be avoided by computing the softmax for a chunk and then iteratively rescaling and updating it as new K chunks are processed. Note that the algorithm is rather low-level, since the proposed implementation uses CUDA with direct memory control; it is also parameterized by the SRAM size M.

  2. Flash-Attention-2 (https://arxiv.org/pdf/2307.08691.pdf) - improves on top of Flash-Attention. The modified algorithm reduces the number of non-matmul FLOPs by moving one output-scaling operation from the loop body to a single scaling at the end. It also swaps the nested loops, such that the outer loop (over query chunks) is embarrassingly parallel, hence it can be used to parallelize the work in addition to the batch size and attention heads (both of which are usually small in cases where a long sequence length is used). Again, this paper discusses the implementation at the CUDA level.

    Note: swapping the loop order is somewhat of a prerequisite for the deferred scaling. With that change, the final scaling can be done after the inner loop. Otherwise, it would have to happen after both loops, which means either loading the whole output into SRAM or doing another pass over all chunks, both of which would be very counterproductive.

    Most modern GPUs contain specialized units to accelerate matrix multiply in low-precision (e.g., Tensor Cores on Nvidia GPUs for FP16/BF16 matrix multiply)

  3. JAX memory-efficient attention (https://arxiv.org/pdf/2112.05682.pdf) - this paper discusses similar ideas to 1. and 2., but at a higher level. It provides a JAX implementation of memory-efficient attention.

    Note: the paper focuses on TPU, and also doesn't show speed benefits, only memory improvements.

    Dao et al. (2022) provide a CUDA implementation of memory-efficient attention and demonstrate that the reduced memory requirements can translate to significant speedups on GPUs. One reason why we do not observe the same performance gains in this paper is that standard self-attention already balances the available FLOPs and memory bandwidth of TPUs.

  4. Flash-Decoding (https://pytorch.org/blog/flash-decoding) - this approach builds on top of Flash-Attention-2, but focuses on inference, specifically autoregressive decoding. It highlights that in autoregressive decoding, subsequent passes work with a single query, and the main bottleneck is reading the K, V chunks from memory (in a loop), while GPU utilization is very low. The improved algorithm first splits K, V into a couple of groups; within each group it splits K, V further into chunks and iteratively computes the result exactly as Flash-Attention-2 does. The groups, however, are computed in parallel, on separate GPU threads, and once they finish, the results are combined (a sketch of the combine step follows this list). This increases GPU utilization.
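
As a sketch of the combine step from 4. (hypothetical JAX code, assuming each group returns its unnormalized output o, softmax normalizer l, and running max m - i.e. the loop carry from the chunked sketch above, before the final division):

```python
import jax.numpy as jnp

def combine_groups(os, ls, ms):
    # os: list of [seq_q, d_v] unnormalized outputs, one per K/V group
    # ls: list of [seq_q] softmax normalizers
    # ms: list of [seq_q] running maxima
    m = jnp.max(jnp.stack(ms), axis=0)        # global max per query row
    scales = [jnp.exp(mi - m) for mi in ms]   # rescale each group's state
    l = sum(s * li for s, li in zip(scales, ls))
    o = sum(s[:, None] * oi for s, oi in zip(scales, os))
    return o / l[:, None]
```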

jonatanklosko merged commit eca3735 into main on Dec 12, 2023
2 checks passed
jonatanklosko deleted the jk-attention branch on December 12, 2023 at 07:13
@jonatanklosko (Member, Author) commented

I just noticed there is a Flax implementation in hf/diffusers, but this comment confirms that the benefits are not exactly significant (and again it's TPU).
