Skip to content

Commit

Permalink
per block descale_p and descale_ds
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Feb 4, 2025
1 parent 249b8b4 commit 1ed8bbb
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 74 deletions.
69 changes: 43 additions & 26 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def _bwd_preprocess_use_o(
stride_descale_q_z = H
descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_h)

# do is scaled in the fp8 range and o is in fp8 but should be the same scale as fp32
# TODO: descale do so that we can use it as fp32
# NOTE: do is scaled into the fp8 range and o is in fp8 but should be in the same scale as fp32
delta = tl.sum(o.to(tl.float32) * (do * descale_do).to(tl.float32), axis=1)
else:
delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1)
Expand Down Expand Up @@ -139,7 +138,6 @@ def _bwd_kernel_one_col_block(
descale_q,
descale_k,
descale_v,
descale_p,
descale_do,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
Expand All @@ -151,6 +149,7 @@ def _bwd_kernel_one_col_block(
USE_EXP2: tl.constexpr,
GROUP_SIZE: tl.constexpr,
IS_FP8: tl.constexpr,
FP8_MAX: tl.constexpr,
):
if CAUSAL:
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
Expand Down Expand Up @@ -246,8 +245,14 @@ def _bwd_kernel_one_col_block(

# compute dv
if IS_FP8:
p_fp8_scaled = p_drop_scaled * (1.0/ descale_p) # put p into fp8 range
dv += (tl.dot(tl.trans(p_fp8_scaled).to(do.type.element_ty), do) * descale_p * descale_do)
# compute descale_p
p_max = tl.max(p_drop_scaled)
p_max = tl.where(p_max <= 1e-9, 1e-9, p_max)
scale_p = FP8_MAX / p_max
descale_p = p_max / FP8_MAX

# NOTE: put p into fp8 range and multiple by do which is already in the fp8 range by the user
dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do)
else:
dv += tl.dot(tl.trans(p_drop_scaled).to(do.type.element_ty), do)

Expand All @@ -261,8 +266,14 @@ def _bwd_kernel_one_col_block(

# compute dv
if IS_FP8:
p_fp8_scaled = p * (1.0/ descale_p) # put p into fp8 range
dv += (tl.dot(tl.trans(p_fp8_scaled).to(do.type.element_ty), do) * descale_p * descale_do)
# compute descale_p
p_max = tl.max(p)
p_max = tl.where(p_max <= 1e-9, 1e-9, p_max)
scale_p = FP8_MAX / p_max
descale_p = p_max / FP8_MAX

# NOTE: put p into fp8 range and multiple by do which is already in the fp8 range by the user
dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do)
else:
dv += tl.dot(tl.trans(p).to(do.type.element_ty), do)

Expand All @@ -281,31 +292,34 @@ def _bwd_kernel_one_col_block(
dscores_scaled = (p * (dp - delta_i[:, None]))
ds = dscores_scaled * sm_scale
ds = tl.where(p_mask, ds, 0.0)

print("ds:", ds)

# print("p:", p)
# print("dp:", dp)
# print("delta_i:", delta_i)
# print("ds:", ds) # NOTE:is almost the same between fp8 and fp16
descale_ds = descale_p
ds_fp8_scaled = ds * (1.0/ descale_ds)
# print("ds_fp8_scaled:", ds_fp8_scaled)
# compute descale_ds
if IS_FP8:
ds_max = tl.max(ds)
ds_max = tl.where(ds_max <= 1e-9, 1e-9, ds_max)
scale_ds = FP8_MAX / ds_max
descale_ds = ds_max / FP8_MAX
else:
scale_ds, descale_ds = 1.0, 1.0

# compute dk
if IS_FP8:
dk += (tl.dot(tl.trans(ds_fp8_scaled).to(q.type.element_ty), q) * descale_ds * descale_q)
dk += (tl.dot(tl.trans(ds * scale_ds).to(q.type.element_ty), q) * descale_ds * descale_q)
else:
dk += tl.dot(tl.trans(ds).to(q.type.element_ty), q)

# compute dq
if SEQUENCE_PARALLEL:
if IS_FP8:
dq = (tl.dot(ds_fp8_scaled.to(k.type.element_ty), k) * descale_ds * descale_k)
dq = (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k)
else:
dq = tl.dot(ds.to(k.type.element_ty), k)
else:
dq = tl.load(dq_ptrs, mask=q_mask, other=0.0)
if IS_FP8:
dq += (tl.dot(ds_fp8_scaled.to(k.type.element_ty), k) * descale_ds * descale_k)
dq += (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k)
else:
dq += tl.dot(ds.to(k.type.element_ty), k)
tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask)
Expand Down Expand Up @@ -340,7 +354,6 @@ def _bwd_kernel(
DESCALE_q,
DESCALE_k,
DESCALE_v,
DESCALE_p,
DESCALE_do,
stride_dq_all,
stride_qz,
Expand Down Expand Up @@ -382,6 +395,7 @@ def _bwd_kernel(
IS_VARLEN: tl.constexpr,
GROUP_SIZE: tl.constexpr,
IS_FP8: tl.constexpr,
FP8_MAX: tl.constexpr,
):
# program ids
off_zh = tl.program_id(0)
Expand Down Expand Up @@ -434,10 +448,9 @@ def _bwd_kernel(
descale_q = tl.load(DESCALE_q + off_z * stride_descale_q_z + off_hq)
descale_k = tl.load(DESCALE_k + off_z * stride_descale_kv_z + off_hk)
descale_v = tl.load(DESCALE_v + off_z * stride_descale_kv_z + off_hk)
descale_p = tl.load(DESCALE_p + off_z * stride_descale_q_z + off_hq)
descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_hq)
else:
descale_q, descale_k, descale_v, descale_p, descale_do = 1.0, 1.0, 1.0, 1.0, 1.0
descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0

# output tensor offsets
dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn
Expand Down Expand Up @@ -499,7 +512,6 @@ def _bwd_kernel(
descale_q,
descale_k,
descale_v,
descale_p,
descale_do,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
Expand All @@ -510,7 +522,8 @@ def _bwd_kernel(
DROPOUT=DROPOUT,
USE_EXP2=USE_EXP2,
GROUP_SIZE=GROUP_SIZE,
IS_FP8=IS_FP8
IS_FP8=IS_FP8,
FP8_MAX=FP8_MAX
)
else:
for start_n in range(0, num_block_n):
Expand Down Expand Up @@ -564,7 +577,6 @@ def _bwd_kernel(
descale_q,
descale_k,
descale_v,
descale_p,
descale_do,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
Expand All @@ -575,7 +587,8 @@ def _bwd_kernel(
DROPOUT=DROPOUT,
USE_EXP2=USE_EXP2,
GROUP_SIZE=GROUP_SIZE,
IS_FP8=IS_FP8
IS_FP8=IS_FP8,
FP8_MAX=FP8_MAX
)


Expand Down Expand Up @@ -635,6 +648,10 @@ def attention_prefill_backward_triton_impl(
print("philox_offset:", philox_offset)
print("use_exp2:", use_exp2)
print("sequence_parallel:", sequence_parallel)
print("descale_q:", descale_q)
print("descale_k:", descale_k)
print("descale_v:", descale_v)
print("descale_do:", descale_do)

is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
if is_fp8:
Expand Down Expand Up @@ -807,7 +824,6 @@ def attention_prefill_backward_triton_impl(
descale_q,
descale_k,
descale_v,
descale_p,
descale_do,
stride_dq_all,
stride_qz, stride_qh, stride_qm, stride_qk,
Expand Down Expand Up @@ -838,7 +854,8 @@ def attention_prefill_backward_triton_impl(
waves_per_eu = waves_per_eu,
IS_VARLEN=is_varlen,
GROUP_SIZE=group_size,
IS_FP8=is_fp8
IS_FP8=is_fp8,
FP8_MAX=torch.finfo(torch.float8_e4m3fnuz).max
)

if sequence_parallel:
Expand Down
30 changes: 17 additions & 13 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
descale_q, descale_k, descale_v, descale_p, IS_FP8: tl.constexpr,
descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr,
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
Expand Down Expand Up @@ -179,8 +179,14 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
m_i = m_ij

if IS_FP8:
p *= (1.0/ descale_p) # put p into fp8 range
acc += (tl.dot(p.to(v.type.element_ty), v) * descale_p * descale_v)
# compute descale_p
p_max = tl.max(p)
p_max = tl.where(p_max <= 1e-9, 1e-9, p_max)
scale_p = FP8_MAX / p_max
descale_p = p_max / FP8_MAX

# NOTE: put p into fp8 range and multiple by do which is already in the fp8 range by the user
acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v)
else:
acc += tl.dot(p.to(v.type.element_ty), v)

Expand Down Expand Up @@ -282,7 +288,8 @@ def attn_fwd(Q, K, V, bias,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr):
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr,
IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
Expand Down Expand Up @@ -416,9 +423,8 @@ def attn_fwd(Q, K, V, bias,
descale_q = tl.load(DESCALE_Q + off_z * stride_descale_q_z + off_h_q)
descale_k = tl.load(DESCALE_K + off_z * stride_descale_kv_z + off_h_k)
descale_v = tl.load(DESCALE_V + off_z * stride_descale_kv_z + off_h_k)
descale_p = tl.load(DESCALE_P + off_z * stride_descale_p_z + off_h_q)
else:
descale_q, descale_k, descale_v, descale_p = 1.0, 1.0, 1.0, 1.0
descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
Expand All @@ -445,7 +451,7 @@ def attn_fwd(Q, K, V, bias,
sd_mask_ptrs, dropout_mask_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, alibi_slope,
descale_q, descale_k, descale_v, descale_p, IS_FP8,
descale_q, descale_k, descale_v, IS_FP8, FP8_MAX,
# IS_CAUSAL, ....
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
Expand Down Expand Up @@ -473,7 +479,7 @@ def attn_fwd(Q, K, V, bias,
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, descale_p, IS_FP8,
n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, IS_FP8, FP8_MAX,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
Expand Down Expand Up @@ -610,11 +616,10 @@ def attention_prefill_forward_triton_impl(
descale_q_stride_z = descale_q.stride(0)
descale_k_stride_z = descale_k.stride(0)
descale_v_stride_z = descale_v.stride(0)
descale_p_stride_z = descale_p.stride(0)
else:
# For non-FP8 types, use dummy values (no scaling needed)
descale_q = descale_k = descale_v = descale_p = 1
descale_q_stride_z = descale_k_stride_z = descale_v_stride_z = descale_p_stride_z = 0
descale_q_stride_z = descale_k_stride_z = descale_v_stride_z = 0


if DEBUG:
Expand All @@ -626,7 +631,6 @@ def attention_prefill_forward_triton_impl(
print("descale_q_stride_z:", descale_q_stride_z)
print("descale_k_stride_z:", descale_k_stride_z)
print("descale_v_stride_z:", descale_v_stride_z)
print("descale_p_stride_z:", descale_p_stride_z)
if is_fp8:
print(f"type_max: {type_max}")

Expand Down Expand Up @@ -690,15 +694,15 @@ def attention_prefill_forward_triton_impl(


attn_fwd[grid](q, k, v, bias,
descale_q, descale_k, descale_v, descale_p, descale_q_stride_z, descale_k_stride_z, descale_p_stride_z,
descale_q, descale_k, descale_v, None, descale_q_stride_z, descale_k_stride_z, None,
sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8)
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8, FP8_MAX=torch.finfo(torch.float8_e4m3fnuz).max)

if DEBUG:
print()
Expand Down
69 changes: 34 additions & 35 deletions flash_attn/flash_attn_triton_amd/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,42 +739,42 @@ def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
@pytest.mark.parametrize(
"Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
[
(1, 1, 1, 1, 1, 1),
# (1, 1, 1, 1, 1, 1),
(1, 1, 1, 2, 2, 16),
(1, 1, 1, 2, 4, 16),
(1, 2, 2, 2, 4, 16),
(1, 4, 1, 2, 4, 16),
(1, 4, 2, 2, 4, 16),
(1, 1, 1, 4, 2, 16),
(1, 1, 1, 4, 4, 16),
(1, 2, 2, 4, 4, 16),
(2, 1, 1, 4, 4, 16),
(2, 2, 2, 4, 4, 16),
(1, 1, 1, 128, 64, 16),
(2, 2, 2, 2, 128, 1),
(2, 3, 3, 2, 128, 16),
(3, 2, 2, 256, 512, 16),
(3, 3, 3, 128, 128, 64),
(2, 4, 4, 1024, 1024, 64),
(4, 6, 6, 108, 256, 224),
(4, 8, 8, 2048, 2048, 128),
(4, 16, 16, 4096, 4096, 64),
(2, 4, 4, 8192, 8192, 32),
# fa configs
(4, 6, 1, 113, 203, 256),
(4, 6, 1, 128, 217, 256),
(4, 6, 2, 113, 211, 128),
(4, 6, 2, 108, 256, 128),
(4, 6, 1, 256, 512, 64),
(4, 6, 1, 512, 256, 64),
(4, 6, 2, 1024, 1024, 32),
(4, 6, 2, 1023, 1024, 32),
(4, 6, 6, 1024, 1023, 32),
(4, 6, 6, 2048, 2048, 32),
# (1, 1, 1, 2, 4, 16),
# (1, 2, 2, 2, 4, 16),
# (1, 4, 1, 2, 4, 16),
# (1, 4, 2, 2, 4, 16),
# (1, 1, 1, 4, 2, 16),
# (1, 1, 1, 4, 4, 16),
# (1, 2, 2, 4, 4, 16),
# (2, 1, 1, 4, 4, 16),
# (2, 2, 2, 4, 4, 16),
# (1, 1, 1, 128, 64, 16),
# (2, 2, 2, 2, 128, 1),
# (2, 3, 3, 2, 128, 16),
# (3, 2, 2, 256, 512, 16),
# (3, 3, 3, 128, 128, 64),
# (2, 4, 4, 1024, 1024, 64),
# (4, 6, 6, 108, 256, 224),
# (4, 8, 8, 2048, 2048, 128),
# (4, 16, 16, 4096, 4096, 64),
# (2, 4, 4, 8192, 8192, 32),
# # fa configs
# (4, 6, 1, 113, 203, 256),
# (4, 6, 1, 128, 217, 256),
# (4, 6, 2, 113, 211, 128),
# (4, 6, 2, 108, 256, 128),
# (4, 6, 1, 256, 512, 64),
# (4, 6, 1, 512, 256, 64),
# (4, 6, 2, 1024, 1024, 32),
# (4, 6, 2, 1023, 1024, 32),
# (4, 6, 6, 1024, 1023, 32),
# (4, 6, 6, 2048, 2048, 32),
],
)
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('dropout_p', [0.0, 0.25])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('DEBUG_INPUT', [False])
@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device")
def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT):
Expand Down Expand Up @@ -828,7 +828,6 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
descale_q = q_max / type_max
descale_k = k_max / type_max
descale_v = v_max / type_max
descale_p = torch.full_like(descale_q, 1.0 / type_max, dtype=torch.float32, device=q.device)
descale_do = do_max / type_max

# scale values to fp8 range
Expand All @@ -852,7 +851,7 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p,
descale_p=None,
descale_do=descale_do
)

Expand Down

0 comments on commit 1ed8bbb

Please sign in to comment.