https://github.com/feifeibear/long-context-attention/blob/0.6.0/yunchang/kernels/attention.py#L47
The torch version is 2.3.
Hi, I found experimentally that for the following case:
batch_size = 1
num_heads = 2
head_dim = 128
seq_len = 16
the lse returned by torch.ops.aten._scaled_dot_product_efficient_attention is buggy: it is padded along the sequence dimension with inf values. I suggest using _scaled_dot_product_flash_attention here instead, which returns the correct lse.

import torch

batch_size = 1
num_heads = 2
head_dim = 128
seq_len = 16
dtype = torch.float16
device = 'cuda'

torch.manual_seed(42)
query = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
key = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
value = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)

# Reference output from the public SDPA API
out1 = torch.nn.functional.scaled_dot_product_attention(query, key, value)

# Flash backend: lse has shape (batch, heads, seq_len)
out2, lse2 = torch.ops.aten._scaled_dot_product_flash_attention(query, key, value)[:2]
print(f'lse2: {lse2}')

# Efficient backend: lse comes back padded past seq_len with inf
out3, lse3 = torch.ops.aten._scaled_dot_product_efficient_attention(query, key, value,
                                                                    attn_bias=None, compute_log_sumexp=True)[:2]
print(f'lse3: {lse3}')

print(f'Result: {torch.allclose(out1, out2, rtol=1e-3, atol=1e-3)}')
print(f'Result: {torch.allclose(out1, out3, rtol=1e-3, atol=1e-3)}')
print(f'Result: {torch.allclose(lse2, lse3, rtol=1e-3, atol=1e-3)}')  # note: lse2 and lse3 have different shapes here

Output:
lse2: tensor([[[5.4660, 5.7685, 5.4396, 5.6371, 5.3691, 5.4923, 5.4666, 5.4459,
                5.5439, 5.7281, 5.8045, 5.6622, 5.7923, 5.6987, 5.5620, 5.5473],
               [5.4989, 5.8029, 5.5886, 5.5052, 5.6427, 5.5984, 5.7117, 5.4015,
                5.6134, 5.5992, 5.4512, 5.8386, 5.8852, 5.3351, 5.6285, 5.6732]]],
             device='cuda:0')
lse3: tensor([[[5.4660, 5.7685, 5.4396, 5.6371, 5.3691, 5.4923, 5.4666, 5.4459,
                5.5439, 5.7281, 5.8045, 5.6622, 5.7923, 5.6987, 5.5620, 5.5473,
                   inf,    inf,    inf,    inf,    inf,    inf,    inf,    inf,
                   inf,    inf,    inf,    inf,    inf,    inf,    inf,    inf],
               [5.4989, 5.8029, 5.5886, 5.5052, 5.6427, 5.5984, 5.7117, 5.4015,
                5.6134, 5.5992, 5.4512, 5.8386, 5.8852, 5.3351, 5.6285, 5.6732,
                   inf,    inf,    inf,    inf,    inf,    inf,    inf,    inf,
                   inf,    inf,    inf,    inf,    inf,    inf,    inf,    inf]]],
             device='cuda:0')
Result: True
Result: True
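The trailing inf values look like pure padding (the sequence length 16 padded out to 32). A quick follow-up check, reusing the tensors from the script above, should confirm that trimming lse3 back to seq_len recovers lse2:

# Trim the efficient kernel's lse back to the true sequence length; if the extra
# entries are only padding, it should match the flash kernel's lse exactly.
lse3_trimmed = lse3[..., :seq_len]
print(lse3_trimmed.shape)                                        # expected: torch.Size([1, 2, 16])
print(torch.allclose(lse2, lse3_trimmed, rtol=1e-3, atol=1e-3))  # expected: True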
https://github.com/feifeibear/long-context-attention/blob/0.6.0/yunchang/kernels/attention.py#L47, this line can be changed to the following:

out, lse = torch.ops.aten._scaled_dot_product_flash_attention(
    q.permute(0, 2, 1, 3),
    key.permute(0, 2, 1, 3),
    value.permute(0, 2, 1, 3),
    dropout_p=dropout_p,
    is_causal=causal,
    scale=softmax_scale,
)[:2]
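One detail worth double-checking in the MR: the permutes in the call above suggest q/key/value are laid out as (batch, seq, heads, dim), while the flash op works in (batch, heads, seq, dim), so the returned out may need to be permuted back before it is used downstream. A sketch only, since the actual layout in attention.py is an assumption here:

# Assumed layout fix-up: restore (batch, seq, heads, dim) if that is what the caller expects.
out = out.permute(0, 2, 1, 3)
# lse from the flash kernel has shape (batch, heads, seq_len), with no inf padding.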
Thanks @neonhuang! Could you submit an MR? If the torch version is < 2.3, should we execute the code you pasted?
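For what it's worth, one possible reading of that question is a version-gated dispatch, sketched below. The wrapper name, the (batch, heads, seq, dim) layout, and the exact cutoff at torch 2.3 are assumptions that would need to be confirmed in the MR:

import torch

# Assumption: q, k, v are (batch, heads, seq, dim), as in the repro script above.
_TORCH_GE_2_3 = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (2, 3)

def attention_with_lse(q, k, v, dropout_p=0.0, causal=False, softmax_scale=None):
    if _TORCH_GE_2_3:
        # torch >= 2.3: the flash kernel returns a clean (batch, heads, seq) lse.
        out, lse = torch.ops.aten._scaled_dot_product_flash_attention(
            q, k, v, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale,
        )[:2]
    else:
        # Older torch: keep the efficient kernel and trim the inf padding off its lse.
        # Note: this kernel's signature has changed across torch releases, so the
        # call below may need adjusting for very old versions.
        out, lse = torch.ops.aten._scaled_dot_product_efficient_attention(
            q, k, v, attn_bias=None, compute_log_sumexp=True,
            dropout_p=dropout_p, is_causal=causal, scale=softmax_scale,
        )[:2]
        lse = lse[..., : q.shape[-2]]
    return out, lse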