[Installation][non-reproducible]: Op Flash Attention #17

Open
antferdom opened this issue Oct 22, 2024 · 12 comments

@antferdom

antferdom commented Oct 22, 2024

The current project repository assumes an existing submodules directory for all the optional dependencies. The Python installation script executes checkout_submodules, but that is only relevant if the submodules directory has already been populated (git submodule add). The expected .gitmodules is:

```ini
[submodule "submodules/FBGEMM"]
	path = submodules/FBGEMM
	url = https://github.com/pytorch/fbgemm
[submodule "submodules/flash-attention"]
	path = submodules/flash-attention
	url = https://github.com/Dao-AILab/flash-attention.git
[submodule "submodules/ThunderKittens"]
	path = submodules/ThunderKittens
	url = https://github.com/HazyResearch/ThunderKittens.git
[submodule "submodules/cutlass-kernels"]
	path = submodules/cutlass-kernels
	url = https://github.com/ColfaxResearch/cutlass-kernels.git
[submodule "submodules/generative-recommenders"]
	path = submodules/generative-recommenders
	url = https://github.com/facebookresearch/generative-recommenders.git
[submodule "submodules/kernels"]
	path = submodules/kernels
	url = https://github.com/triton-lang/kernels.git
[submodule "submodules/cutlass"]
	path = submodules/cutlass
	url = https://github.com/NVIDIA/cutlass.git
```

Then we execute git submodule update --init --recursive.
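A quick pre-flight check can make the missing-submodule failure explicit before checkout_submodules runs. The sketch below is not part of Tritonbench (both helper names are hypothetical); it parses .gitmodules with a deliberately simple regex, which is enough for the file above but is not a full git-config parser, and reports submodule paths that are missing or empty.

```python
import re
from pathlib import Path

# Section header, e.g.: [submodule "submodules/FBGEMM"]
_SECTION = re.compile(r'^\s*\[submodule\s+"([^"]+)"\]\s*$')
# Key/value line (usually tab-indented), e.g.: path = submodules/FBGEMM
_KEYVAL = re.compile(r"^\s*(\w+)\s*=\s*(.+?)\s*$")


def parse_gitmodules(text: str) -> dict[str, dict[str, str]]:
    """Return {name: {"path": ..., "url": ...}} parsed from .gitmodules text."""
    modules: dict[str, dict[str, str]] = {}
    current = None
    for line in text.splitlines():
        section = _SECTION.match(line)
        if section:
            current = modules.setdefault(section.group(1), {})
            continue
        keyval = _KEYVAL.match(line)
        if current is not None and keyval:
            current[keyval.group(1)] = keyval.group(2)
    return modules


def unpopulated_submodules(repo_root: Path) -> list[str]:
    """Return the submodule paths that are missing or empty under repo_root."""
    modules = parse_gitmodules((repo_root / ".gitmodules").read_text())
    missing = []
    for name, cfg in modules.items():
        rel_path = cfg.get("path", name)
        sub_dir = repo_root / rel_path
        # A submodule that was never checked out is absent or an empty directory.
        if not sub_dir.is_dir() or not any(sub_dir.iterdir()):
            missing.append(rel_path)
    return missing
```

For example, unpopulated_submodules(Path("/workspace/tritonbench")) should return an empty list once git submodule update --init --recursive has succeeded.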

Dockerfile

Two final steps of the Docker image build are commented out because they reference bash installation scripts in a non-existent .ci directory, presumably available in Meta's internal repo. The commented-out Dockerfile section:

```dockerfile
# Tritonbench library build and test require libcuda.so.1
# which is from NVIDIA driver
RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 patchelf

# Install Tritonbench
#RUN cd /workspace/tritonbench && \
#    bash .ci/tritonbench/install.sh


# Test Tritonbench
#RUN cd /workspace/tritonbench && \
#    bash .ci/tritonbench/test-install.sh

# Remove NVIDIA driver library - they are supposed to be mapped at runtime
RUN sudo apt-get purge -y libnvidia-compute-550

# Clone the pytorch env as triton-main env, then compile triton main from source
RUN cd /workspace/tritonbench && \
    BASE_CONDA_ENV=${CONDA_ENV} CONDA_ENV=${CONDA_ENV_TRITON_MAIN}
    # bash .ci/tritonbench/install-triton-main.sh
```

Instead, we run the container in interactive mode (docker exec -it) and, following the comment for that script, build Triton from source.

@xuzhao9

Contributor

xuzhao9 commented Oct 24, 2024

Thanks for reporting the issue. Sorry, there are some sync issues going on and we are working on fixing them.

@antferdom
Author

Is there anything I can do to help?

@xuzhao9
Contributor

xuzhao9 commented Oct 25, 2024

@antferdom Can you help verify that the problem is fixed by 748a883?

@antferdom
Author

@xuzhao9 Sure, I will try it out tomorrow.

@antferdom
Author

antferdom commented Oct 28, 2024

@xuzhao9

  • Commit cebd61d modified several run.py imports, so now either we install pytorch.benchmark as an additional dependency or it won't run (see from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS).
  • The Docker image build should be corrected so that it installs all the optional dependencies in the git submodules.
  • Tritonbench is not installed as a Python module in the current commit.
  • xformers is not built from source. This is crucial if we want to use the xformers backend with the Torch custom op wrapper logic for FlashAttention v3.
  • Building Triton from source with install-triton-main.sh leads to TMA errors (e.g. run.py --op flash_attention --mode fwd). Removing that installation and installing the latest nightly release instead solves the issue. However, the flash_attention op then hangs indefinitely after the first iteration, without printing any error. Torch compilation could explain a long execution time, but even after 1h it is still on the first benchmarking iteration.

op: flash_attention

```
(pytorch) runner@compiler-study-hopper:/workspace/tritonbench$ python run.py --op flash_attention --mode fwd
TMA benchmarks will be running without grid constant TMA descriptor.
 11%|████████████████▉                                                                                                                                       | 1/9 [00:10<01:25, 10.67s/it]
```
  • ThunderKittens installation: wrong shared object path:
```
In [2]: from tritonbench.utils.loader import load_library

In [3]: load_library("tk/tk_attn_h100_fwd.so")
---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
Cell In[3], line 1
----> 1 load_library("tk/tk_attn_h100_fwd.so")

File /workspace/tritonbench/tritonbench/utils/loader.py:9, in load_library(library_path)
      7 prefix, _delimiter, so_file = library_path.partition("/")
      8 so_full_path = REPO_PATH.joinpath(prefix, ".data", so_file).resolve()
----> 9 torch.ops.load_library(str(so_full_path))

File ~/miniconda3/envs/pytorch/lib/python3.11/site-packages/torch/_ops.py:1357, in _Ops.load_library(self, path)
   1352 path = _utils_internal.resolve_library_path(path)
   1353 with dl_open_guard():
   1354     # Import the shared library into the process, thus running its
   1355     # static (global) initialization code in order to register custom
   1356     # operators with the JIT.
-> 1357     ctypes.CDLL(path)
   1358 self.loaded_libraries.add(path)

File ~/miniconda3/envs/pytorch/lib/python3.11/ctypes/__init__.py:376, in CDLL.__init__(self, name, mode, handle, use_errno, use_last_error, winmode)
    373 self._FuncPtr = _FuncPtr
    375 if handle is None:
--> 376     self._handle = _dlopen(self._name, mode)
    377 else:
    378     self._handle = handle

OSError: /workspace/tritonbench/tritonbench/tk/.data/tk_attn_h100_fwd.so: cannot open shared object file: No such file or directory
```

The correct path is /workspace/tritonbench/utils/tk/.data/tk_attn_h100_fwd.so; the resolved path seems to contain an additional tritonbench. Loading the .so from this full path directly, instead of through the utility function:

```
In [1]: import torch

In [2]: torch.ops.load_library("/workspace/tritonbench/utils/tk/.data/tk_attn_h100_fwd.so")

In [3]:  tk_fwd = torch.ops.tk
```
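The path mismatch can be reproduced with plain pathlib by mirroring the two lines of load_library shown in the traceback. This is a sketch, not the fix that landed for this issue: the two root directories below are inferred from the error message and from the path that worked, and find_so is a hypothetical fallback helper.

```python
from pathlib import Path, PurePath, PurePosixPath
from typing import Optional


def so_path(repo_path: PurePath, library_path: str) -> PurePath:
    # Mirrors lines 7-8 of load_library in the traceback (without .resolve()):
    # "tk/tk_attn_h100_fwd.so" -> <repo_path>/tk/.data/tk_attn_h100_fwd.so
    prefix, _delimiter, so_file = library_path.partition("/")
    return repo_path.joinpath(prefix, ".data", so_file)


def find_so(roots: list[Path], library_path: str) -> Optional[Path]:
    """Hypothetical fallback: return the first candidate .so that exists, else None."""
    for root in roots:
        candidate = Path(so_path(root, library_path))
        if candidate.is_file():
            return candidate
    return None


# Root implied by the OSError message (an inference, not confirmed from the repo).
print(so_path(PurePosixPath("/workspace/tritonbench/tritonbench"), "tk/tk_attn_h100_fwd.so"))
# /workspace/tritonbench/tritonbench/tk/.data/tk_attn_h100_fwd.so

# Root implied by the path that loaded successfully with torch.ops.load_library.
print(so_path(PurePosixPath("/workspace/tritonbench/utils"), "tk/tk_attn_h100_fwd.so"))
# /workspace/tritonbench/utils/tk/.data/tk_attn_h100_fwd.so
```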

Would it be possible to get access to an existing Docker container that is known to run?

@xuzhao9
Contributor

xuzhao9 commented Oct 28, 2024

> • Commit cebd61d modified several run.py imports, so now either we install pytorch.benchmark as an additional dependency or it won't run (see from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS).
> • The Docker image build should be corrected so that it installs all the optional dependencies in the git submodules.
> • Tritonbench is not installed as a Python module in the current commit.

The above three errors should be fixed by 7e60e23.

Created #20 to track the progress.

> • Building Triton from source with install-triton-main.sh leads to TMA errors (e.g. run.py --op flash_attention --mode fwd). Removing that installation and installing the latest nightly release instead solves the issue. However, the flash_attention op then hangs indefinitely after the first iteration, without printing any error. Torch compilation could explain a long execution time, but even after 1h it is still on the first benchmarking iteration.


In the Docker image, we use separate conda environments to manage the Triton versions:
  • The pytorch env uses the Triton version bundled with the latest PyTorch nightly release.
  • The triton-main env uses the Triton main branch.

Can you please try running the command in the triton-main env? If the problem still exists, please create another issue to track it.
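To confirm which Triton build a given conda env actually picks up, a standard-library check like the one below may help (a sketch; the script is not part of Tritonbench). It is relevant here because the environment dump later in this thread lists both pytorch-triton and triton as installed in the same env, and both install a top-level triton package, so it is worth checking where import triton resolves.

```python
import importlib.metadata
import importlib.util


def dist_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def module_origin(module_name):
    """Return the file a top-level module would be imported from, without importing it."""
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    print(dist_versions(["torch", "triton", "pytorch-triton"]))
    print("triton imported from:", module_origin("triton"))
```

Running it once per env, e.g. conda run -n triton-main python check_triton.py (check_triton.py being a hypothetical file name for the script above), shows side by side which Triton each env resolves.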

> • ThunderKittens installation: wrong shared object path:

I am working on this, will test and have a PR soon.
#22 should fix this.

@antferdom
Author

> In the Docker image, we use separate conda environments to manage the Triton versions:
> • The pytorch env uses the Triton version bundled with the latest PyTorch nightly release.
> • The triton-main env uses the Triton main branch.
>
> Can you please try running the command in the triton-main env? If the problem still exists, please create another issue to track it.

  • Triton built from source: TMA descriptor problem in the FA2 implementation:
```
(triton-main) runner@compiler-study-hopper:/workspace/tritonbench$ python run.py --op flash_attention --mode fwd
TMA benchmarks will be running without grid constant TMA descriptor.
  0%|                                                                                                                                                                | 0/9 [00:01<?, ?it/s]
Caught exception, terminating early with partial results
Traceback (most recent call last):
  File "/workspace/tritonbench/tritonbench/utils/triton_op.py", line 708, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
                                                  ^^^^^^^^^^^^^^^^^
  File "/workspace/tritonbench/tritonbench/utils/triton_op.py", line 696, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
                   ^^^^^^^^^^^^^^^
  File "/workspace/tritonbench/tritonbench/utils/triton_op.py", line 915, in _do_bench
    metrics.latency = triton.testing.do_bench(
                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/triton/testing.py", line 106, in do_bench
    fn()
  File "/workspace/tritonbench/tritonbench/operators/flash_attention/operator.py", line 253, in <lambda>
    return lambda: triton_tutorial_FA2_tma(q, k, v, self.causal, self.sm_scale)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/tritonbench/tritonbench/kernels/triton_fused_attention.py", line 1147, in forward
    _attn_fwd_tma[grid_tma](
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/triton/testing.py", line 106, in do_bench
    fn()
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 114, in kernel_call
    self.fn.run(
  File "/home/runner/miniconda3/envs/triton-main/lib/python3.11/site-packages/triton/runtime/jit.py", line 683, in run
    grid = grid(bound_args)
           ^^^^^^^^^^^^^^^^
  File "/workspace/tritonbench/tritonbench/kernels/triton_fused_attention.py", line 1086, in grid_tma
    desc_helper.fill_2d_tma_descriptor(
  File "/workspace/tritonbench/tritonbench/kernels/triton_fused_attention.py", line 101, in fill_2d_tma_descriptor
    self.fill_2d_tma_descriptor_inner(
TypeError: a bytes-like object is required, not 'int'
  (Batch, Heads, SeqLen, Dhead)
```

Environment information

Collecting environment information...
PyTorch version: 2.6.0.dev20241028+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.5
Libc version: glibc-2.35

Python version: 3.11.10 (main, Oct  3 2024, 07:29:13) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 550.90.12
cuDNN version: Probably one of the following:
/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudnn.so.9.1.0
/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudnn_adv.so.9.1.0
/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudnn_cnn.so.9.1.0
/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudnn_engines_precompiled.so.9.1.0
/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudnn_graph.so.9.1.0
/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudnn_heuristic.so.9.1.0
/usr/local/cuda-12.4/targets/x86_64-linux/lib/libcudnn_ops.so.9.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               30
On-line CPU(s) list:                  0-29
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8462Y+
CPU family:                           6
Model:                                143
Thread(s) per core:                   1
Core(s) per socket:                   1
Socket(s):                            30
Stepping:                             8
BogoMIPS:                             5600.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            960 KiB (30 instances)
L1i cache:                            960 KiB (30 instances)
L2 cache:                             120 MiB (30 instances)
L3 cache:                             480 MiB (30 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-29
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241028+cu124
[pip3] triton==3.0.0+git8cdba567
[conda] magma-cuda124             2.6.1                         1    pytorch
[conda] numpy                     2.1.2                    pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pytorch-triton            3.1.0+cf34004b8a          pypi_0    pypi
[conda] torch                     2.6.0.dev20241028+cu124          pypi_0    pypi
[conda] triton                    3.0.0+git8cdba567          pypi_0    pypi

@xuzhao9
Contributor

xuzhao9 commented Oct 28, 2024

For the flash-attention operator there are two problems:

  1. The triton_fused_attention kernel does not work with triton-main: we do not plan to fix that soon (soon meaning within one or two weeks), as many other kernels may not work with triton-main at this time either. We plan to first build the CI and systematically find out how many of the existing kernels don't work with triton-main. In particular, we need the PT2 team to get involved if there are incompatibilities between TorchInductor and triton-main.

  2. flash_attention with PyTorch's built-in Triton should work out of the box, and it is a problem if any of the implementations hang. I will take a look at this.

@antferdom
Author

antferdom commented Oct 29, 2024

  1. triton_fused_attention kernel does not work with triton-main: Understood. That sounds like a good systematic way to isolate the kernels that don't work with triton-main.
  2. It runs, but after the first iteration completes (as reported by tqdm) the execution hangs indefinitely (I left it running in the environment mentioned above for around 1h). GPU utilization is at maximum, but memory usage does not change. I will try to isolate the problem by going through the attention implementations one by one in a standalone script, without Tritonbench's operator registry logic -> test_attn

I'll keep an eye on the open PR, but both the TK import path fix and the xformers change look good to me.

@xuzhao9
Copy link
Contributor

xuzhao9 commented Oct 29, 2024

For 2), you could use --only to select a single attention implementation; this should make it easier to bisect/isolate.
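Building on the --only suggestion, a small driver can run each implementation in its own process with a timeout, so a hang shows up as a timeout instead of blocking the whole sweep. This is a sketch under stated assumptions: the implementation names are taken from the latency columns of the results table later in this thread and are assumed to be valid --only values, and the timeout is arbitrary.

```python
import subprocess
import sys

# Assumed --only values, taken from the "<name>-latency" columns reported later in this thread.
IMPLS = ["sdpa", "aten", "flash_v2", "flash_v3", "triton_tutorial_flash_v2",
         "triton_tutorial_flash_v2_tma", "flex_attention"]


def tritonbench_cmd(impl):
    """Command line for benchmarking a single implementation (run from the repo root)."""
    return [sys.executable, "run.py", "--op", "flash_attention", "--mode", "fwd", "--only", impl]


def isolate(impls, build_cmd, timeout_s):
    """Run each impl in a separate process and classify it as ok / failed / timeout."""
    results = {}
    for impl in impls:
        try:
            # Output is not captured, so the benchmark's own logs stay visible.
            proc = subprocess.run(build_cmd(impl), timeout=timeout_s)
            results[impl] = "ok" if proc.returncode == 0 else "failed"
        except subprocess.TimeoutExpired:
            # subprocess.run kills the child before re-raising; processes it spawned may survive.
            results[impl] = "timeout"
    return results
```

From /workspace/tritonbench, isolate(IMPLS, tritonbench_cmd, timeout_s=600) would then report which implementation, if any, hangs.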

facebook-github-bot pushed a commit that referenced this issue Oct 29, 2024
Summary:
As mentioned in #17

Pull Request resolved: #22

Test Plan:
```
(base) ➜  tritonbench git:(xz9/fix-tk-load) python -c "from tritonbench.utils.loader import load_library; load_library('tk/tk_attn_h100_fwd.so')"
(base) ➜  tritonbench git:(xz9/fix-tk-load) echo $?
0
```

Reviewed By: FindHao

Differential Revision: D65147532

Pulled By: xuzhao9

fbshipit-source-id: 669b2aa4db9b0581a2dbf709e9b577dbbe5e670e
@antferdom
Author

True, I will use --only to isolate the implementations one at a time.

@antferdom antferdom changed the title [Installation][non-reproducible]: Missing git submodules [Installation][non-reproducible]: Op Flash Attention Oct 29, 2024
@antferdom
Author

Using the latest Tritonbench Docker image, I was finally able to fully run the flash_attention operator:

```
/workspace/tritonbench$ python run.py --op flash_attention --mode fwd --precision bf16
TMA benchmarks will be running without grid constant TMA descriptor.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:05<00:00,  7.25s/it]
  (Batch, Heads, SeqLen, Dhead)    sdpa-latency    aten-latency    flash_v2-latency    flash_v3-latency    triton_tutorial_flash_v2-latency    triton_tutorial_flash_v2_tma-latency    flex_attention-latency
-------------------------------  --------------  --------------  ------------------  ------------------  ----------------------------------  --------------------------------------  ------------------------
              (32, 32, 512, 64)        0.262816        2.88429             0.253344            0.212928                            0.204864
             (16, 32, 1024, 64)        0.483936        5.64954             0.461056            0.360672                            0.371744
              (8, 32, 2048, 64)        0.928736       13.2899              0.900224            0.621792                            0.685088
              (4, 32, 4096, 64)        1.81328        23.6284              1.74131             1.18486                             1.38013
              (2, 32, 8192, 64)        3.59718        45.2292              3.44819             2.19107                             2.62384
             (1, 32, 16384, 64)        7.15782                             6.85792             4.28938                             5.21392                                                           5.20269
               (4, 32, 19, 128)        0.009248        0.032416            0.01328             0.012032                            0.00832                                                           0.099776
                (4, 32, 1, 128)        0.009056        0.025888            0.013888            0.01152                             0.007904                                                          0.098656
              (4, 32, 511, 128)        0.066976        0.54688             0.065312            0.048288                            0.060288                                                          0.082976
```

However, several attention implementations are registered with enabled=False:

  • xformers
  • xformers_splitk
  • colfax_cutlass
  • cudnn
  • thunderkittens
