What's the difference between the flash attention implementations in cuDNN and Dao-AILab? #52
Comments
There is no difference in the algorithm or numerics between cuDNN and Dao-AILab. The cuDNN implementation benefits from NVIDIA's in-house expertise in kernel development and aims to maximize hardware capabilities. Please use the following samples as code snippets for using cuDNN flash attention:
@gautam20197 As far as I know, flash attention has been implemented by NVIDIA in TensorFlow, right?
@MoFHeka, it is not correct to say it is implemented in TensorFlow; it is implemented in XLA, and there is a PR openxla/xla#6872 pending to integrate the final piece of flash attention into XLA. Once this PR is merged, you can access flash attention from JAX/TensorFlow if the pattern is supported.
@Cjkkkk So if I understand correctly, in addition to TF/JAX, PyTorch can also use OpenXLA to work with cuDNN.
@MoFHeka PyTorch eager mode has a path to cuDNN's optimized attention.
I think we've addressed the original question. Going to close for now.
Is there any benchmark comparing cuDNN fused attention and flash attention? I recently found that TorchACC already supports using cuDNN fused attention in PyTorch training, so there must be a benchmark, right? Even end-to-end performance from C++ code would help. @mnicely @Cjkkkk @gautam20197 I am eager to know how I should verify whether the acceleration after I turn on XLA has reached the ideal state.
I think you can check your use case using the PyTorch nightlies and running the PyTorch SDPA example: https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
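For example, something along these lines (a rough sketch for illustration, assuming a recent build where torch.nn.attention.sdpa_kernel and SDPBackend.CUDNN_ATTENTION are available; the shapes are only illustrative):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes in the (batch, heads, seq_len, head_dim) layout expected by SDPA.
q = torch.randn(1, 16, 4096, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 4096, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 4096, 128, dtype=torch.float16, device="cuda")

# Restrict SDPA to the cuDNN backend so the call fails loudly if that backend is unsupported.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)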
@mnicely Thank you very much for your answer. May I ask how much improvement you have measured compared to Dao-AILab FlashAttention-2 in your evaluation?
We recently released cuDNN v9.
We say "up to" because it depends on the parameters.
@mnicely I noticed that speed-up benchmark in the cuDNN release notes recently. Yes, it looks great. But are there any more details on the QKV shapes and other settings?
@MoFHeka The problem sizes with hidden dimension per head (d) = 128 are the best for gaining a significant speedup on both Hopper and Ampere.
@gautam20197 Head dimension (d) = 128 with any batch size or sequence length?
Yes, there will be a healthy speedup for all batch sizes and sequence lengths.
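For example, for a fixed hidden width of 2048, a 16-head split (d = 128) hits that sweet spot while a 32-head split (d = 64) does not; a rough sketch for illustration (not an official benchmark):

import torch

b, s, width = 1, 4096, 2048

# Two head splits of the same hidden width; per the comment above, d = 128 is the faster configuration.
for h, d in ((32, 64), (16, 128)):
    assert h * d == width
    q = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")
    out = torch.nn.functional.scaled_dot_product_attention(q, q, q, is_causal=True)
    print(h, d, out.shape)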
@mnicely Unable to run:

>>> import torch
>>> t1=torch.randn(1,4,4096,128).to("cuda").to(torch.float16)
>>> torch._scaled_dot_product_cudnn_attention(t1, t1, t1, dropout_p=0.0, is_causal=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1711266070736/work/aten/src/ATen/native/cudnn/MHA.cpp":410, please report a bug to PyTorch.
@mnicely torch 2.3, but it still fails with:

RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans built successfully.

I'm not exactly sure whether I should post it here, but this issue and this repo look the most suitable to me.
@MoFHeka @arogozhnikov Can you both try again with the latest nightlies? The following should work:
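For instance, a minimal check along these lines (a sketch for illustration, not the exact snippet referenced above; it assumes the TORCH_CUDNN_SDPA_ENABLED=1 flag that appears later in this thread):

import os
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"  # flag used elsewhere in this thread to opt into the cuDNN SDPA path

import torch

# Same shape as the failing repro above: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 4, 4096, 128, dtype=torch.float16, device="cuda")
out = torch.nn.functional.scaled_dot_product_attention(q, q, q, is_causal=True)
print(torch.__version__, out.shape)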
@mnicely Unfortunately it still does not work with the nightlies. A100, 40 GB; the error is the same: RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans built successfully.
@mnicely I have tested cuDNN attention on an A30 with the image nvcr.io/nvidia/pytorch:24.04-py3; it is much slower than flash attention in the same image.
=====================TEST CUDNN Attention=====================
Are there any methods to improve performance?

import os
import torch
import time

os.environ['TORCH_CUDNN_SDPA_ENABLED'] = '1'

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
    except ImportError:
        flash_attn_unpadded_func = None

try:
    from einops import rearrange
except ImportError:
    rearrange = None

causal = False

print("=====================TEST CUDNN Attention=====================")
b = 1
s = 4096
count = 0
torch.cuda.empty_cache()
while True:
    for h, d in zip((32, 16), (64, 128)):
        q, k, v = torch.randn(b, s, h*d*3, dtype=torch.bfloat16, device='cuda', requires_grad=True).chunk(3, dim=-1)
        q = q.view(b, -1, h, d).transpose(1, 2)
        k = k.view(b, -1, h, d).transpose(1, 2)
        v = v.view(b, -1, h, d).transpose(1, 2)
        with torch.no_grad():
            for i in range(5):
                out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal)
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal)
            torch.cuda.synchronize()
            t2 = time.perf_counter()
        print(f"{(b, s, h, d)} {t2-t1}")
    if b > 1:
        b //= 2
    s *= 2
    if s > 131072:
        break

print("=====================TEST Flash Attention=====================")
b = 1
s = 4096
count = 0
torch.cuda.empty_cache()
while True:
    for h, d in zip((32, 16), (64, 128)):
        q, k, v = torch.randn(b, s, h, 3*d, dtype=torch.bfloat16, device='cuda', requires_grad=True).chunk(3, dim=-1)
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                    device=q.device)
        cu_seqlens_k = cu_seqlens_q
        with torch.no_grad():
            for i in range(5):
                out = flash_attn_unpadded_func(
                    q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                    0.0, causal=causal
                )
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            out = flash_attn_unpadded_func(
                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                0.0, causal=causal
            )
            torch.cuda.synchronize()
            t2 = time.perf_counter()
        print(f"{(b, s, h, d)} {t2-t1}")
    if b > 1:
        b //= 2
    s *= 2
    if s > 131072:
        break
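For what it's worth, a less noisy way to time these kernels (a sketch for illustration, not part of the script above) is to use CUDA events and average over several iterations:

import torch

def time_fn(fn, iters=20, warmup=5):
    # Average time per call in milliseconds, measured with CUDA events.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

q = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")
ms = time_fn(lambda: torch.nn.functional.scaled_dot_product_attention(q, q, q, is_causal=False))
print(f"{ms:.3f} ms per call")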
Is this link a flash attention implementation?