55 changes: 26 additions & 29 deletions flash_attn/flash_attn_interface.py
@@ -214,7 +214,6 @@ def _flash_attn_varlen_forward_fake(
paged_kv = block_table is not None
batch_size = cu_seqlens_q.numel() - 1
total_q, num_heads, _ = q.shape

out = torch.empty_like(q)
softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
@@ -252,8 +251,7 @@ def _flash_attn_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
(
dq,
@@ -281,7 +279,7 @@ def _flash_attn_backward(
None,
rng_state,
)
return softmax_d
return dq.clone(), dk.clone(), dv.clone(), softmax_d


@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
@@ -304,18 +302,17 @@ def _flash_attn_backward_fake(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
if dq is None:
dq = torch.empty_like(q)
if dk is None:
dk = torch.empty_like(k)
if dv is None:
dv = torch.empty_like(v)
batch_size, seqlen_q, num_heads, _ = q.shape
softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)

return softmax_d
if torch.cuda.is_available() and torch.version.hip:
softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32)
else:
softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
# dq, dk, dv are already allocated in the forward pass;
# we pass them through here to match the C++ signature and to help torch.compile infer shapes during tracing.
# Without this, torch.compile struggles to infer the shape of softmax_d.
return dq, dk, dv, softmax_d


if torch.__version__ >= "2.4.0":
@@ -348,8 +345,7 @@ def _flash_attn_varlen_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
(
dq,
@@ -382,9 +378,8 @@ def _flash_attn_varlen_backward(
None,
rng_state,
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return softmax_d
# return clones, otherwise torch.compile will complain about mutated tensors being returned
return dq.clone(), dk.clone(), dv.clone(), softmax_d


@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
@@ -411,20 +406,22 @@ def _flash_attn_varlen_backward_fake(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
batch_size = cu_seqlens_q.numel() - 1
total_q, num_heads, _ = q.shape

if dq is None:
dq = torch.empty_like(q)
if dk is None:
dk = torch.empty_like(k)
if dv is None:
dv = torch.empty_like(v)
softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)

return softmax_d

# The CUDA kernel appears to round up max_seqlen_q to a multiple of 128
if torch.cuda.is_available() and torch.version.hip:
softmax_d = torch.empty((batch_size, num_heads, max_seqlen_q), device=q.device, dtype=torch.float32)
else:
softmax_d = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128)), device=q.device, dtype=torch.float32)

# dq, dk, dv are already allocated in the forward pass;
# we pass them through here to match the C++ signature and to help torch.compile infer shapes during tracing.
# Without this, torch.compile struggles to infer the shape of softmax_d.
return dq, dk, dv, softmax_d


if torch.__version__ >= "2.4.0":
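For readers following the pattern above: below is a minimal, self-contained sketch (not the flash-attn code itself; the op name `toy::attn_backward`, the shapes, and the dummy kernel body are hypothetical) of how a backward op that fills pre-allocated dq/dk/dv buffers in place can be registered together with a fake implementation, so that torch.compile can trace it and infer the output shapes, and why the real op returns clones of the mutated buffers.

```python
# Minimal sketch, assuming the public torch.library custom-op API (torch >= 2.4).
# Everything named "toy" here is hypothetical and only mirrors the pattern above.
from typing import Tuple

import torch


@torch.library.custom_op("toy::attn_backward", mutates_args=("dq", "dk", "dv"))
def toy_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    dv: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Stand-in for the real CUDA/HIP kernel: write gradients into the
    # pre-allocated buffers in place.
    dq.zero_()
    dk.zero_()
    dv.zero_()
    batch_size, seqlen_q, num_heads, _ = q.shape
    softmax_d = torch.empty(
        (batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device
    )
    # Return clones: the op must not return tensors it also mutated, which is
    # the same reason the real backward returns dq.clone(), dk.clone(), dv.clone().
    return dq.clone(), dk.clone(), dv.clone(), softmax_d


@toy_attn_backward.register_fake
def _(dout, q, k, v, dq, dk, dv):
    # Fake impl: no computation, only the shapes/dtypes of the outputs so that
    # torch.compile can infer them (in particular softmax_d) during tracing.
    batch_size, seqlen_q, num_heads, _ = q.shape
    softmax_d = torch.empty(
        (batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device
    )
    return torch.empty_like(dq), torch.empty_like(dk), torch.empty_like(dv), softmax_d
```

During tracing only the fake runs, so returning correctly shaped dq/dk/dv and softmax_d there is what lets torch.compile propagate gradient shapes without executing the kernel.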
12 changes: 8 additions & 4 deletions tests/test_flash_attn_ck.py
@@ -302,8 +302,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("compiled", [False, True])
def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, compiled
):
device = "cuda"
# set seed
@@ -343,7 +344,8 @@ def test_flash_attn_output(
return_attn_probs=True,
)
else:
out, lse, S_dmask = flash_attn_func(
flash_func = torch.compile(flash_attn_func) if compiled else flash_attn_func
out, lse, S_dmask = flash_func(
q,
k,
v,
@@ -519,8 +521,9 @@ def test_flash_attn_output(
],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("compiled", [False, True])
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, compiled
):
device = "cuda"
# set seed
@@ -598,7 +601,8 @@ def test_flash_attn_varlen_output(
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
flash_varlen_func = torch.compile(flash_attn_varlen_func) if compiled else flash_attn_varlen_func
out_unpad, sm_lse, S_dmask = flash_varlen_func(
q_unpad,
k_unpad,
v_unpad,
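A hedged sketch of what the new `compiled` parametrization in these tests exercises end to end: wrapping the public API in torch.compile and running forward plus backward. The shapes, dtype, and causal flag below are illustrative rather than copied from the test file, and a CUDA/ROCm device with flash-attn installed is assumed.

```python
import torch
from flash_attn import flash_attn_func

# Illustrative shapes: (batch, seqlen, nheads, headdim)
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Same pattern as the tests: optionally wrap the function in torch.compile.
compiled_flash_attn = torch.compile(flash_attn_func)
out = compiled_flash_attn(q, k, v, causal=True)

# The backward pass now also traces cleanly, since the registered fake
# wrappers tell torch.compile the shapes of dq, dk, dv and softmax_d.
out.sum().backward()
```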