What's the difference between the flash attention implementations in cuDNN and Dao-AILab? #52

Open
MoFHeka opened this issue Dec 20, 2023 · 19 comments

Comments

MoFHeka commented Dec 20, 2023

Is the implementation at this link flash attention?

@gautam20197

There is no difference in the algorithm or numerics between cuDNN and Dao-AILab. The cuDNN implementation benefits from NVIDIA's in-house expertise in kernel development and aims to make the most of the hardware's capabilities.

Please use the following samples as reference code for using cuDNN flash attention:
C++: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/samples/cpp/mha.cpp
Python: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/samples/python/test_mhas.py
Documentation: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/docs/operations/Attention.md
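To illustrate what "same algorithm and numerics" means here, the sketch below compares naive attention with a flash-attention-style computation that visits keys/values in blocks and keeps a running ("online") softmax per query row, so the full score matrix is never stored. It is a minimal single-head NumPy illustration (no masking, no dropout, arbitrary block size), not code taken from cuDNN or Dao-AILab; the two results should agree up to floating-point rounding.

```python
import numpy as np


def naive_attention(q, k, v):
    """softmax(q @ k.T / sqrt(d)) @ v, materializing the full (n_q, n_k) score matrix."""
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize exp()
    p = np.exp(scores)
    return (p @ v) / p.sum(axis=-1, keepdims=True)


def tiled_attention(q, k, v, block=16):
    """Same result computed flash-attention style: K/V are visited block by block
    and each query row keeps a running max and a running softmax denominator."""
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((n_q, v.shape[-1]))
    row_max = np.full(n_q, -np.inf)  # running max score per query row
    row_sum = np.zeros(n_q)          # running softmax denominator per query row
    for start in range(0, k.shape[0], block):
        k_blk = k[start:start + block]
        v_blk = v[start:start + block]
        scores = (q @ k_blk.T) * scale
        new_max = np.maximum(row_max, scores.max(axis=-1))
        rescale = np.exp(row_max - new_max)  # corrects earlier blocks (0 on the first block)
        p = np.exp(scores - new_max[:, None])
        row_sum = row_sum * rescale + p.sum(axis=-1)
        out = out * rescale[:, None] + p @ v_blk
        row_max = new_max
    return out / row_sum[:, None]


rng = np.random.default_rng(0)
q = rng.standard_normal((128, 64))
k = rng.standard_normal((256, 64))
v = rng.standard_normal((256, 64))
# Maximum absolute difference between the two methods (tiny, rounding-level)
print(np.abs(naive_attention(q, k, v) - tiled_attention(q, k, v)).max())
```

Production kernels fuse these steps on-chip and differ in operation order, which is why agreement is expected to a tolerance rather than bit for bit.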

MoFHeka (Author) commented Dec 21, 2023

@gautam20197 As far as I know, NVIDIA has implemented flash attention in TensorFlow, right?
cuda_dnn.cc

Cjkkkk commented Jan 2, 2024

@MoFHeka, it is not correct to say it is implemented in TensorFlow; it is implemented in XLA, and PR openxla/xla#6872 is pending to integrate the final piece of flash attention into XLA. Once this PR is merged, you can access flash attention from JAX/TensorFlow if the attention pattern is supported.

MoFHeka (Author) commented Jan 5, 2024

@Cjkkkk So if I understand correctly, in addition to TF/JAX, PyTorch can also use OpenXLA to work with cuDNN?

mnicely (Collaborator) commented Feb 21, 2024

@MoFHeka PyTorch eager mode has a path to cuDNN's optimized attention.

mnicely (Collaborator) commented Feb 21, 2024

I think we've addressed the original question. Closing for now.

mnicely closed this as completed Feb 21, 2024

MoFHeka (Author) commented Mar 22, 2024

Is there any benchmark comparing cuDNN fused attention with flash attention? I recently found that TorchACC already supports cuDNN fused attention in PyTorch training, so there must be a benchmark, right? Even end-to-end performance numbers from C++ code would help. @mnicely @Cjkkkk @gautam20197

I am eager to know how to check whether the speedup I get after turning on XLA has reached the ideal level.

mnicely (Collaborator) commented Mar 22, 2024

I think you can check your use case with the PyTorch nightlies:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121

Then run the PyTorch SDPA tutorial example https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html

with TORCH_CUDNN_SDPA_ENABLED=1 set in the environment.

MoFHeka (Author) commented Mar 22, 2024

@mnicely Thank you very much for your answer. May I ask how much improvement it gives compared with Dao-AILab's FlashAttention-2 in your evaluation?

mnicely (Collaborator) commented Mar 22, 2024

We recently released cuDNN v9.
FP16 and BF16 fused flash attention engine performance has been significantly improved for NVIDIA GPUs:

  • Speed-up of up to 50% over cuDNN 8.9.7 on Hopper GPUs.
  • Speed-up of up to 100% over cuDNN 8.9.7 on Ampere GPUs.

We say "up to" because the speedup depends on the problem parameters.
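Assuming "speed-up of X%" means throughput relative to cuDNN 8.9.7 (the wording is not defined further in this thread), the implied change in attention runtime works out as:

```latex
S = \frac{t_{8.9.7}}{t_{9}} = 1 + \frac{X}{100}
\quad\Longrightarrow\quad
t_{9} = \frac{t_{8.9.7}}{1 + X/100},
\qquad
X = 50:\; t_{9} \approx 0.67\, t_{8.9.7},
\qquad
X = 100:\; t_{9} = 0.5\, t_{8.9.7}.
```

Under that reading, "up to 100%" on Ampere means at best halving the attention time for a given configuration.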

MoFHeka (Author) commented Mar 24, 2024

@mnicely I noticed that speed-up benchmark in the cuDNN release notes recently. Yes, it looks great. But are there more details, such as the QKV shapes and other settings?
A single speedup figure for one particular situation is not convincing enough; we need a reproducible experimental setup.

Anerudhan reopened this Mar 25, 2024

@gautam20197

@MoFHeka Problem sizes with a per-head hidden dimension of d = 128 gain the most significant speedup on both Hopper and Ampere.

MoFHeka (Author) commented Mar 25, 2024

@gautam20197 Head dimension d = 128 with any batch size or sequence length?

@gautam20197

Yes, there will be a healthy speedup for all batch sizes and sequence lengths.

MoFHeka (Author) commented Mar 26, 2024

@mnicely I followed your suggestion (PyTorch nightly with TORCH_CUDNN_SDPA_ENABLED=1), but I am unable to run it:

>>> import torch
>>> t1=torch.randn(1,4,4096,128).to("cuda").to(torch.float16)
>>> torch._scaled_dot_product_cudnn_attention(t1, t1, t1, dropout_p=0.0, is_causal=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1711266070736/work/aten/src/ATen/native/cudnn/MHA.cpp":410, please report a bug to PyTorch.

arogozhnikov commented May 22, 2024

@mnicely With torch 2.3 I am still unable to use it with TORCH_CUDNN_SDPA_ENABLED=1:

import torch
q, k, v = torch.randn(3, 1, 4, 4096, 128).to("cuda").to(torch.bfloat16)
torch._scaled_dot_product_cudnn_attention(q, k, v)

RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans built successfully.

I'm not sure I should post this here, but this issue and this repo seem the most suitable place.

mnicely (Collaborator) commented May 22, 2024

@MoFHeka @arogozhnikov can you both try again with the latest nightlies? The following should work:

pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124

with

import torch
q, k, v = torch.randn(3, 1, 4, 4096, 128).to("cuda").to(torch.bfloat16)
torch._scaled_dot_product_cudnn_attention(q, k, v)

@arogozhnikov

@mnicely Unfortunately, it still does not work with the nightlies (A100 40GB, TORCH_CUDNN_SDPA_ENABLED=1).

The error is the same: RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans built successfully.

MoFHeka (Author) commented May 28, 2024

@mnicely I have tested cuDNN attention on an A30 with the image nvcr.io/nvidia/pytorch:24.04-py3. It is much slower than flash attention in the same image.

=====================TEST CUDNN Attention=====================
/workspace/qkv_attention.py:34: UserWarning: USING CUDNN SDPA (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:586.)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal)
/workspace/qkv_attention.py:37: UserWarning: USING CUDNN SDPA (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:586.)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal)
(1, 4096, 32, 64) 0.00217504077591002
(1, 4096, 16, 128) 0.0020994041115045547
(1, 8192, 32, 64) 0.00814421707764268
(1, 8192, 16, 128) 0.007756144041195512
(1, 16384, 32, 64) 0.03211699100211263
(1, 16384, 16, 128) 0.024100096197798848
(1, 32768, 32, 64) 0.09487044904381037
(1, 32768, 16, 128) 0.09436172991991043
(1, 65536, 32, 64) 0.3823277521878481
(1, 65536, 16, 128) 0.3809503731317818
(1, 131072, 32, 64) 1.544325471157208
(1, 131072, 16, 128) 1.5516349140089005
=====================TEST Flash Attention=====================
(1, 4096, 32, 64) 0.001407111994922161
(1, 4096, 16, 128) 0.0012988371308892965
(1, 8192, 32, 64) 0.005206326022744179
(1, 8192, 16, 128) 0.0054926639422774315
(1, 16384, 32, 64) 0.023661984829232097
(1, 16384, 16, 128) 0.02126245992258191
(1, 32768, 32, 64) 0.10068793897517025
(1, 32768, 16, 128) 0.08452382893301547
(1, 65536, 32, 64) 0.3777510579675436
(1, 65536, 16, 128) 0.33767699101008475
(1, 131072, 32, 64) 1.5188782080076635
(1, 131072, 16, 128) 1.3448062930256128
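To make the two blocks of numbers above easier to compare, they can be turned into per-shape ratios and achieved TFLOP/s. The sketch below uses only the standard library; the timings are copied from the output above, and the 4·b·h·s²·d forward FLOP count is the common approximation used in FlashAttention-style benchmarks (an assumption here; softmax cost is ignored). Each entry is a single timed call, so small differences may be noise.

```python
# Forward-pass timings in seconds, copied from the A30 output above (batch size 1).
# Keys are (seq_len, num_heads, head_dim).
CUDNN = {
    (4096, 32, 64): 0.00217504077591002,
    (4096, 16, 128): 0.0020994041115045547,
    (8192, 32, 64): 0.00814421707764268,
    (8192, 16, 128): 0.007756144041195512,
    (16384, 32, 64): 0.03211699100211263,
    (16384, 16, 128): 0.024100096197798848,
    (32768, 32, 64): 0.09487044904381037,
    (32768, 16, 128): 0.09436172991991043,
    (65536, 32, 64): 0.3823277521878481,
    (65536, 16, 128): 0.3809503731317818,
    (131072, 32, 64): 1.544325471157208,
    (131072, 16, 128): 1.5516349140089005,
}
FLASH = {
    (4096, 32, 64): 0.001407111994922161,
    (4096, 16, 128): 0.0012988371308892965,
    (8192, 32, 64): 0.005206326022744179,
    (8192, 16, 128): 0.0054926639422774315,
    (16384, 32, 64): 0.023661984829232097,
    (16384, 16, 128): 0.02126245992258191,
    (32768, 32, 64): 0.10068793897517025,
    (32768, 16, 128): 0.08452382893301547,
    (65536, 32, 64): 0.3777510579675436,
    (65536, 16, 128): 0.33767699101008475,
    (131072, 32, 64): 1.5188782080076635,
    (131072, 16, 128): 1.3448062930256128,
}


def attention_fwd_flops(batch, seq, heads, head_dim, causal=False):
    """Approximate forward FLOPs: QK^T and PV are each 2*seq*seq*head_dim per head.

    Softmax cost is ignored; a causal mask skips roughly half of the work.
    """
    flops = 4 * batch * heads * seq * seq * head_dim
    return flops // 2 if causal else flops


for (s, h, d) in sorted(CUDNN):
    flops = attention_fwd_flops(1, s, h, d)  # the benchmark used causal=False
    ratio = CUDNN[(s, h, d)] / FLASH[(s, h, d)]  # > 1 means cuDNN was slower
    print(f"s={s:>6} h={h:>2} d={d:>3}  cudnn/flash={ratio:.2f}  "
          f"TFLOP/s cudnn={flops / CUDNN[(s, h, d)] / 1e12:6.1f}  "
          f"flash={flops / FLASH[(s, h, d)] / 1e12:6.1f}")
```

On these numbers, cuDNN is roughly 1.4 to 1.6 times slower at 4K and 8K tokens, and within about 15% of flash attention from 32K tokens upward (slightly faster for s=32768, d=64). Comparing the TFLOP/s column against the GPU's published peak is one way to judge how close a run is to the ideal.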

Is there any way to improve performance?

import os
import time

# Opt in to the cuDNN SDPA backend; set before importing torch to be safe.
os.environ['TORCH_CUDNN_SDPA_ENABLED'] = '1'

import torch

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        # flash-attn 2.x renamed this API to flash_attn_varlen_func
        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
    except ImportError:
        flash_attn_unpadded_func = None
try:
    from einops import rearrange
except ImportError:
    rearrange = None

# Fail early instead of crashing halfway through the run
if flash_attn_unpadded_func is None or rearrange is None:
    raise SystemExit("This benchmark requires the flash-attn and einops packages.")

causal = False

print("=====================TEST CUDNN Attention=====================")
b = 1
s = 4096
torch.cuda.empty_cache()
while True:
    # (heads, head_dim) pairs with the same model width h * d = 2048
    for h, d in zip((32, 16), (64, 128)):
        # Packed QKV of shape (b, s, 3*h*d), split and viewed as (b, h, s, d) for SDPA
        q, k, v = torch.randn(b, s, h*d*3, dtype=torch.bfloat16, device='cuda', requires_grad=True).chunk(3, dim=-1)
        q = q.view(b, -1, h, d).transpose(1, 2)
        k = k.view(b, -1, h, d).transpose(1, 2)
        v = v.view(b, -1, h, d).transpose(1, 2)
        with torch.no_grad():
            # Warm-up calls, not timed
            for i in range(5):
                out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal)
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal)
            torch.cuda.synchronize()
            t2 = time.perf_counter()
        print(f"{(b, s, h, d)} {t2-t1}")

    if b > 1:
        b //= 2
    s *= 2
    if s > 131072:
        break

print("=====================TEST Flash Attention=====================")
b = 1
s = 4096
torch.cuda.empty_cache()
while True:
    for h, d in zip((32, 16), (64, 128)):
        # Packed QKV of shape (b, s, h, 3*d), split into (b, s, h, d) tensors
        q, k, v = torch.randn(b, s, h, 3*d, dtype=torch.bfloat16, device='cuda', requires_grad=True).chunk(3, dim=-1)
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        # The varlen API takes tokens flattened to (b*s, h, d) plus cumulative sequence offsets
        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                    device=q.device)
        cu_seqlens_k = cu_seqlens_q
        with torch.no_grad():
            # Warm-up calls, not timed
            for i in range(5):
                out = flash_attn_unpadded_func(
                        q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                        0.0, causal=causal
                    )
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            out = flash_attn_unpadded_func(
                    q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                    0.0, causal=causal
                )
            torch.cuda.synchronize()
            t2 = time.perf_counter()
        print(f"{(b, s, h, d)} {t2-t1}")

    if b > 1:
        b //= 2
    s *= 2
    if s > 131072:
        break
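The script above times a single call per shape, so one-off jitter (clock ramp-up, other work on the GPU, first-use plan building) can tilt the comparison. A small, hypothetical helper that repeats the timed call and reports the median is sketched below; the name `benchmark` and its parameters are illustrative and not part of PyTorch or flash-attn. It uses only the standard library, and `sync` is where `torch.cuda.synchronize` would be passed when timing GPU work.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], sync: Callable[[], None] = lambda: None,
              warmup: int = 5, iters: int = 20) -> dict:
    """Time `fn` several times and summarize, instead of trusting one call.

    `sync` must block until queued device work has finished (for example
    torch.cuda.synchronize); the default no-op is only right for CPU-only code.
    """
    for _ in range(warmup):  # warm-up runs: plan building, autotuning, caches
        fn()
    sync()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        sync()  # wait for the timed work to actually finish
        samples.append(time.perf_counter() - t0)
    return {
        "median": statistics.median(samples),
        "min": min(samples),
        "stdev": statistics.stdev(samples) if iters > 1 else 0.0,
    }


print(benchmark(lambda: sum(range(100_000)), warmup=2, iters=10))
```

Inside the `torch.no_grad()` block one could call `benchmark(lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal), sync=torch.cuda.synchronize)` in place of the single timed call, and do the same for the flash-attn call. PyTorch's `torch.utils.benchmark.Timer` is an alternative worth checking, since it handles CUDA synchronization itself.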

Labels: None yet
Projects: None yet
Development: No branches or pull requests
6 participants