Performance compared to local-attention + masking #120
Apologies in advance if I misunderstood your question. I think you're suggesting that we implement neighborhood attention by just modifying the attention-mask logic in FMHA/FAv2 and using the existing sliding-window parameter? If so, the short answer is that we kind of are, but that only covers a very small portion of our use cases.

Long answer: the main difference between the implementations of sliding window attention in FAv2/FMHA and NATTEN is that we're more interested in multi-dimensional inputs (i.e. 2-D and 3-D feature maps representing image and video spaces). Put simply, neighborhood attention is to self attention what convolution is to fully-connected layers (think nn.Linear on the spatial axes as well as the channels). Similarly, it allows for parameters like dilation (but padding and stride don't really make sense).

From an implementation point of view, NATTEN treats the attention problem as 2 GETTs (General Tensor-Tensor Contraction) instead of 2 GEMMs (General Matrix-Matrix Multiply). This basically means we try not to assume that the "row mode" is one-dimensional (a sequence of tokens). It is true that implementing 1-D neighborhood attention in FAv2/FMHA is relatively trivial (NA also supports dilation, but there's an easy hack for supporting it). However, 2-D and 3-D require major modifications to convert the GEMMs in FAv2/FMHA into GETTs, and that's kind of what we're after right now.

Our fused kernels are based on FMHA, and our 1-D kernels are relatively similar in latency to FMHA. That does not hold for 2-D and 3-D, because those are now GETT problems, which complicates the data iteration process somewhat. Speaking mostly for the forward-pass kernel, most of the additional latency would come from that. So we could probably accelerate 1-D neighborhood attention by making those small changes to FAv2, but it would fall short of supporting 2-D and 3-D problems.
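To make the 1-D case concrete, here is a minimal sketch of which keys a query attends to under 1-D neighborhood attention, next to a plain sliding window. It assumes NA's convention that windows near the borders shift inward instead of shrinking, and it reads the "easy hack" for dilation as splitting the sequence into `dilation` interleaved sub-sequences and running undilated NA on each; that reading is an assumption. `na1d_neighbors` and `sliding_window_neighbors` are hypothetical helpers, not NATTEN APIs.

```python
from typing import List


def na1d_neighbors(i: int, length: int, kernel_size: int, dilation: int = 1) -> List[int]:
    """Key indices attended by query i under 1-D neighborhood attention (sketch).

    Assumes an odd kernel_size and that border windows shift inward rather than
    shrink, so every query attends to exactly kernel_size keys.
    """
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size must be odd")
    # Dilation hack (assumed): tokens sharing the same index modulo `dilation`
    # form an independent sub-sequence; run undilated NA inside it.
    group = i % dilation
    pos = i // dilation  # position of query i within its group
    group_len = len(range(group, length, dilation))
    if group_len < kernel_size:
        raise ValueError("each dilation group needs at least kernel_size tokens")
    radius = kernel_size // 2
    # Center the window on the query, then clamp it so it stays inside the group.
    start = min(max(pos - radius, 0), group_len - kernel_size)
    return [group + (start + j) * dilation for j in range(kernel_size)]


def sliding_window_neighbors(i: int, length: int, kernel_size: int) -> List[int]:
    """Symmetric 1-D sliding window that simply truncates at the borders."""
    radius = kernel_size // 2
    return list(range(max(i - radius, 0), min(i + radius, length - 1) + 1))


if __name__ == "__main__":
    for i in (0, 5, 9):
        print(i, na1d_neighbors(i, 10, 3), sliding_window_neighbors(i, 10, 3))
    print(na1d_neighbors(5, 10, 3, dilation=2))  # [3, 5, 7]
```

Away from the borders the two agree; they only differ at the edges, where NA keeps a full window. In a FAv2/FMHA-style kernel that roughly amounts to adjusting the per-query key bounds near the edges.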
What we're really hoping to allow with NATTEN in the future is not just multi-dimensional sliding windows, but explicit spatio-temporal attention (one attention call, with causal masking across certain dimensions and without it across others).
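A hypothetical per-axis mask predicate shows what "causal across certain dimensions, not across others" could look like for a (T, H, W) video token grid. The function names, the per-axis `kernel_size`/`is_causal` arguments, and the choice of a causal window covering the `k` most recent positions are all illustrative assumptions, not an existing NATTEN interface.

```python
from typing import Callable, Sequence, Tuple


def unravel(idx: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major flat index -> multi-dimensional coordinate."""
    coord = []
    for n in reversed(shape):
        idx, c = divmod(idx, n)
        coord.append(c)
    return tuple(reversed(coord))


def spatio_temporal_predicate(
    shape: Sequence[int],
    kernel_size: Sequence[int],
    is_causal: Sequence[bool],
) -> Callable[[int, int], bool]:
    """Hypothetical mask: causal window on some axes, centered NA window on others."""

    def allowed(q_idx: int, kv_idx: int) -> bool:
        q = unravel(q_idx, shape)
        kv = unravel(kv_idx, shape)
        for qc, kc, n, k, causal in zip(q, kv, shape, kernel_size, is_causal):
            if causal:
                # Causal axis: the k most recent positions, including the query's own.
                if not (qc - k + 1 <= kc <= qc):
                    return False
            else:
                # Non-causal axis: centered NA window, shifted inward at the borders.
                start = min(max(qc - k // 2, 0), n - k)
                if not (start <= kc < start + k):
                    return False
        return True

    return allowed


if __name__ == "__main__":
    shape = (4, 5, 5)  # (T, H, W)
    # Fully causal in time (window = T), 3x3 neighborhood in space.
    allowed = spatio_temporal_predicate(shape, (4, 3, 3), (True, False, False))
    q = 2 * 25 + 1 * 5 + 1  # token at (t=2, h=1, w=1)
    print(sum(allowed(q, kv) for kv in range(4 * 5 * 5)))
```

Each axis contributes an independent condition and they are ANDed, which is what lets one call mix causal and non-causal behavior.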
It is possible; I just very seriously doubt it'll have any advantages compared to FNA.
That diagram shows 1-D, 2-D, and 3-D. The idea is that the kernel tiling changes according to the spatial rank, but the threadblock-level MMA is agnostic to that. Loads and stores are mostly what's affected.
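One way to see why tiling depends on spatial rank: the sketch below compares the set of keys that two 64-query tiles would need to load under 2-D NA, one tile being 64 consecutive tokens of the flattened sequence and the other an 8x8 block. Both feed the same number of rows to the MMA; only the load footprint differs. The tile shapes and helper names are illustrative assumptions, not NATTEN's actual configurations.

```python
from typing import Iterable, List, Set


def na2d_neighbors(q_idx: int, height: int, width: int, kernel_size: int) -> List[int]:
    """Flat key indices for a flat query index under 2-D NA (no dilation)."""
    r = kernel_size // 2
    qi, qj = divmod(q_idx, width)
    # Border windows shift inward so each query keeps kernel_size**2 keys.
    si = min(max(qi - r, 0), height - kernel_size)
    sj = min(max(qj - r, 0), width - kernel_size)
    return [(si + a) * width + (sj + b) for a in range(kernel_size) for b in range(kernel_size)]


def keys_for_tile(queries: Iterable[int], height: int, width: int, kernel_size: int) -> Set[int]:
    """Union of key indices a query tile needs, i.e. what the kernel must load."""
    return {kv for q in queries for kv in na2d_neighbors(q, height, width, kernel_size)}


def tile_1d(first: int, size: int) -> List[int]:
    """A tile of `size` consecutive tokens in the flattened sequence."""
    return list(range(first, first + size))


def tile_2d(row0: int, col0: int, tile_h: int, tile_w: int, width: int) -> List[int]:
    """A tile_h x tile_w block of tokens, expressed as flat indices."""
    return [(row0 + a) * width + (col0 + b) for a in range(tile_h) for b in range(tile_w)]


if __name__ == "__main__":
    H = W = 32
    k = 7
    flat = tile_1d(8 * W, 64)        # rows 8-9, all 32 columns
    square = tile_2d(8, 8, 8, 8, W)  # rows 8-15, columns 8-15
    print(len(keys_for_tile(flat, H, W, k)), len(keys_for_tile(square, H, W, k)))  # 256 196
```

For a 32x32 map with a 7x7 kernel, the flat tile needs 256 keys and the square tile 196; the gap grows with the width of the map, because a flat tile's footprint spans whole rows.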
So technically you can implement all forms of neighborhood attention with attention masks; there's no doubt there.
As you pointed out, materializing the attention mask kind of defeats the purpose of reducing the memory footprint with a fused kernel, and I'm pretty sure you can't get rid of the quadratic computation either, because it's not straightforward to modify the 1-D sliding window to replicate 2-D or 3-D. I actually haven't given this much thought recently, but my guess is that it won't be possible. Also consider that dilation might make it even more complicated. The sliding-window aspect of NA makes these things difficult, but the behavior around the corners just makes it worse.

In addition to the extra memory overhead, reading an explicit attention mask from global memory will probably slow down an attention kernel more than the overhead that 2-D and 3-D NA add because of the GEMM -> GETT changes. So, just to clarify: 2-D and 3-D are not as performant as 1-D in terms of bandwidth; however, that doesn't mean FNA2d is so much slower that an explicit attention mask would be faster. It might be for a few small edge cases, but in general I think the surest way to reduce actual global memory reads and FLOPs, and of course to ensure correctness, is to convert the problem into a GETT, as is done for convolutions.
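A small brute-force check illustrates the corner problem. On a row-major flattened H x W map, a single query's 2-D neighborhood is not a contiguous run of tokens, and the 1-D window radius needed to cover every neighborhood is set by the corner queries, whose windows are shifted inward. Helper names are hypothetical; dilation is not modeled.

```python
from typing import List


def na2d_neighbors(q_idx: int, height: int, width: int, kernel_size: int) -> List[int]:
    """Flat key indices for a flat query index under 2-D NA (no dilation)."""
    r = kernel_size // 2
    qi, qj = divmod(q_idx, width)
    si = min(max(qi - r, 0), height - kernel_size)
    sj = min(max(qj - r, 0), width - kernel_size)
    return [(si + a) * width + (sj + b) for a in range(kernel_size) for b in range(kernel_size)]


def is_contiguous(keys: List[int]) -> bool:
    """True if the keys form one unbroken run of flat indices."""
    return max(keys) - min(keys) + 1 == len(keys)


def covering_radius(height: int, width: int, kernel_size: int) -> int:
    """Smallest symmetric 1-D window radius on the flattened sequence that
    contains every 2-D NA neighborhood (brute force over all queries)."""
    return max(
        abs(kv - q)
        for q in range(height * width)
        for kv in na2d_neighbors(q, height, width, kernel_size)
    )


if __name__ == "__main__":
    H = W = 8
    k = 3
    keys = na2d_neighbors(9, H, W, k)  # query at (1, 1)
    print(keys, is_contiguous(keys))   # [0, 1, 2, 8, 9, 10, 16, 17, 18] False
    r = k // 2
    print(r * (W + 1), covering_radius(H, W, k))  # 9 18
```

For an 8x8 map with a 3x3 kernel, interior queries only need a radius of r(W+1) = 9, but the corner queries push it to 18: a 37-token window per query to recover 9 useful keys.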
The explicit attention mask would be my guess. Even if it's just
You are right; with a little bit of modification to a 1-D attention kernel, we should be able to cut off most of what's outside the diagonal region. I'm not so sure if it'll support dilation (because dilation would require holding more than the one stride value, and that forces the GEMM -> GETT conversion, so we'd wind up with FNA2d again), but we'll set that aside for now.
Yes (kind of). It's all still subject to what your GEMM shape / tile size is.
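How much of the dense iteration space a banded 1-D kernel could skip depends on the tile sizes. The sketch below counts the (query tile, key tile) pairs visited when each query tile loops over the contiguous range of key tiles spanning its needed keys, as one might do by modifying a 1-D sliding-window kernel. It is a counting model under those assumptions, not a performance measurement, and the helper names are hypothetical.

```python
import math
from typing import List, Tuple


def na2d_neighbors(q_idx: int, height: int, width: int, kernel_size: int) -> List[int]:
    """Flat key indices for a flat query index under 2-D NA (no dilation)."""
    r = kernel_size // 2
    qi, qj = divmod(q_idx, width)
    si = min(max(qi - r, 0), height - kernel_size)
    sj = min(max(qj - r, 0), width - kernel_size)
    return [(si + a) * width + (sj + b) for a in range(kernel_size) for b in range(kernel_size)]


def band_tile_visits(height: int, width: int, kernel_size: int,
                     tile_q: int, tile_k: int) -> Tuple[int, int]:
    """Return (visited, total) tile pairs for a banded kernel on the flattened map.

    For each query tile, the kernel is assumed to iterate over every key tile from
    the one holding the smallest needed key to the one holding the largest.
    `total` is the number of tile pairs dense attention would visit.
    """
    n = height * width
    visited = 0
    for q0 in range(0, n, tile_q):
        keys = [kv for q in range(q0, min(q0 + tile_q, n))
                for kv in na2d_neighbors(q, height, width, kernel_size)]
        visited += max(keys) // tile_k - min(keys) // tile_k + 1
    total = math.ceil(n / tile_q) * math.ceil(n / tile_k)
    return visited, total


if __name__ == "__main__":
    for tile in (1, 8, 16, 32, 64):
        v, t = band_tile_visits(32, 32, 7, tile, tile)
        print(f"tile={tile:3d} visited {v}/{t} ({v / t:.1%})")
```

Small tiles skip more of the off-diagonal region but waste work inside each visited tile; large tiles do the opposite, which is why the answer depends on the GEMM shape.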
Thanks for the clarifications! I understand why FNA2d is implemented the way it is, but I'd nevertheless be curious about its performance compared to local attention. In particular, I think the "materializing the attention mask" issue can be resolved if you compute the attention mask within the kernel itself (thus never materializing it in gmem). So a good baseline might be comparing FNA2d against "full" local attention. It is not a full 1:1 comparison, because FNA2d can take advantage of some additional sparsity, while "full" local self-attention doesn't need to handle the additional index computations for the mask, but it might be a good baseline to compare against.
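Computing the mask inside the kernel means the mask becomes a function of (query index, key index) rather than a stored tensor. Here is a slow pure-Python reference with an online softmax, in the style of fused kernels, where a 2-D NA predicate is evaluated on the fly. Note that it still checks every key, so it removes the mask's memory traffic but not the quadratic loop; actually skipping work needs tile-level bounds. `attention_with_predicate` and `na2d_predicate` are hypothetical names.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]


def attention_with_predicate(q: Matrix, k: Matrix, v: Matrix,
                             allowed: Callable[[int, int], bool]) -> Matrix:
    """Reference attention where the mask is a function of indices.

    The predicate is evaluated per (query, key) pair inside the loop, so no
    N x N mask is ever stored. Assumes every query is allowed at least one key.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi, q_row in enumerate(q):
        m, l = -math.inf, 0.0    # running max and softmax denominator
        acc = [0.0] * len(v[0])  # running unnormalized output
        for ki, k_row in enumerate(k):
            if not allowed(qi, ki):
                continue         # masked: no score, no exp, no read of v[ki]
            s = scale * sum(a * b for a, b in zip(q_row, k_row))
            m_new = max(m, s)
            corr = math.exp(m - m_new)  # rescales earlier partial sums (0.0 on the first key)
            p = math.exp(s - m_new)
            l = l * corr + p
            acc = [a * corr + p * x for a, x in zip(acc, v[ki])]
            m = m_new
        out.append([a / l for a in acc])
    return out


def na2d_predicate(height: int, width: int, kernel_size: int) -> Callable[[int, int], bool]:
    """2-D NA membership test computed from flat indices (no dilation)."""
    r = kernel_size // 2

    def start(c: int, n: int) -> int:
        # Centered window, shifted inward at the borders.
        return min(max(c - r, 0), n - kernel_size)

    def allowed(q_idx: int, kv_idx: int) -> bool:
        qi, qj = divmod(q_idx, width)
        ki, kj = divmod(kv_idx, width)
        si, sj = start(qi, height), start(qj, width)
        return si <= ki < si + kernel_size and sj <= kj < sj + kernel_size

    return allowed


if __name__ == "__main__":
    H = W = 4
    n = H * W
    q = [[math.sin(i), math.cos(i)] for i in range(n)]
    k = [[math.cos(i), math.sin(i)] for i in range(n)]
    v = [[float(i), 1.0] for i in range(n)]
    print(attention_with_predicate(q, k, v, na2d_predicate(H, W, 3))[0])
```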
That's fair; it is something I'm curious about myself. I guess I can take a look at this once we wrap up the backward kernel for FNA.
Could you clarify this? If we're modifying a 1-D attention kernel to do 2-D NA with attention masking, it would still have to do extra index computations and checks (and the checks are more likely than the index computations to contribute to latency, because of branching). By the way, I really appreciate all the feedback. God knows I'm not going to get it from conference reviews haha.
I mean that an "easy"-ish comparison would be to run a local/sliding-window attention kernel and compare it against FNA2d. It won't exactly match the semantics of FNA2d, but it would give an upper bound on how efficient a local/sliding-window attention kernel would be compared to FNA2d.
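For that comparison, a natural baseline is FLOP-matched: a 1-D window with k*k keys per query does the same matmul work as 2-D NA with a k x k kernel, so any latency gap between the two would come from overhead (indexing, GETT-style data iteration) rather than arithmetic. A back-of-the-envelope sketch, assuming the 1-D window shifts inward at the borders like 1-D NA (a truncating window would do slightly less work) and ignoring the cost of the softmax; function names are hypothetical.

```python
from typing import Dict


def attention_flops(num_queries: int, keys_per_query: int, head_dim: int) -> int:
    """Matmul FLOPs for one head: Q.K^T and P.V each cost head_dim
    multiply-adds (2 * head_dim FLOPs) per attended (query, key) pair."""
    return 2 * (2 * num_queries * keys_per_query * head_dim)


def compare(height: int, width: int, kernel_size: int, head_dim: int) -> Dict[str, int]:
    """FLOPs for 2-D NA, a FLOP-matched 1-D window, and dense attention."""
    n = height * width
    return {
        "na2d": attention_flops(n, kernel_size ** 2, head_dim),
        "window_1d": attention_flops(n, kernel_size ** 2, head_dim),  # window = k*k
        "dense": attention_flops(n, n, head_dim),
    }


if __name__ == "__main__":
    for name, flops in compare(56, 56, 7, 32).items():
        print(f"{name:10s} {flops / 1e6:8.1f} MFLOPs")
```

With equal arithmetic, timing FNA2d against the matched 1-D kernel isolates the cost of multi-dimensional indexing, which is exactly the overhead being discussed.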
Haha, I think NATTEN is quite cool :) We've been thinking about how to handle more of these kinds of attention "variants" in PyTorch, and NATTEN is a pretty nice example.
Okay, yeah so I did try a few problem sizes before we released, but most of them were comparing really large kernel sizes. I'll definitely come up with a set of problem sizes and test 2D and 3D against 1D very soon.
Oh wow, thank you 🙂. That means a great deal coming from you.
I saw #89
Can you elaborate some more on the performance differences between NATTEN and local/sliding-window attention? I understand that NATTEN is not a special case of local/sliding-window attention, but from my understanding, you should be able to implement NATTEN using 1. local/sliding-window attention, and 2. a special attention mask.
Is that correct? In that case, I think you would get most of the performance benefits of NATTEN (i.e. you skip much of the unnecessary computation).