Commit

use max(abs())
micmelesse committed Feb 4, 2025
1 parent 1ed8bbb commit 38b1b46
Showing 3 changed files with 50 additions and 53 deletions.
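
Summary of the change: the FP8 scale/descale factors were previously derived from tl.max(x), which gives a sign-flipping negative scale whenever a block is all-negative (as ds can be in the backward pass) and is fragile near zero; the commit derives them from the absolute maximum, tl.max(tl.abs(x)), instead. A minimal sketch of the scale/descale round trip in plain PyTorch (illustrative only, not the kernel code; assumes a PyTorch build with float8_e4m3fn, and FP8_MAX here is the e4m3 max, not necessarily the kernel's constant):

import torch

FP8_MAX = 448.0  # e4m3 max, illustrative

x = torch.randn(16, 16)
amax = x.abs().max().clamp(min=1e-9)  # mirrors tl.where(amax <= 1e-9, 1e-9, amax)
scale, descale = FP8_MAX / amax, amax / FP8_MAX

x_fp8 = (x * scale).to(torch.float8_e4m3fn)  # scaled to span the fp8 range
x_back = x_fp8.to(torch.float32) * descale   # descale restores the original magnitude
print((x - x_back).abs().max())              # small quantization error, no sign flips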
26 changes: 12 additions & 14 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -246,10 +246,10 @@ def _bwd_kernel_one_col_block(
 # compute dv
 if IS_FP8:
     # compute descale_p
-    p_max = tl.max(p_drop_scaled)
-    p_max = tl.where(p_max <= 1e-9, 1e-9, p_max)
-    scale_p = FP8_MAX / p_max
-    descale_p = p_max / FP8_MAX
+    p_amax = tl.max(tl.abs(p_drop_scaled))
+    p_amax = tl.where(p_amax <= 1e-9, 1e-9, p_amax)
+    scale_p = FP8_MAX / p_amax
+    descale_p = p_amax / FP8_MAX
 
     # NOTE: put p into the fp8 range and multiply by do, which the user has already put into the fp8 range
     dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do)
@@ -267,10 +267,10 @@ def _bwd_kernel_one_col_block(
 # compute dv
 if IS_FP8:
     # compute descale_p
-    p_max = tl.max(p)
-    p_max = tl.where(p_max <= 1e-9, 1e-9, p_max)
-    scale_p = FP8_MAX / p_max
-    descale_p = p_max / FP8_MAX
+    p_amax = tl.max(tl.abs(p))
+    p_amax = tl.where(p_amax <= 1e-9, 1e-9, p_amax)
+    scale_p = FP8_MAX / p_amax
+    descale_p = p_amax / FP8_MAX
 
     # NOTE: put p into the fp8 range and multiply by do, which the user has already put into the fp8 range
     dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do)
@@ -292,15 +292,13 @@ def _bwd_kernel_one_col_block(
 dscores_scaled = (p * (dp - delta_i[:, None]))
 ds = dscores_scaled * sm_scale
 ds = tl.where(p_mask, ds, 0.0)
-
-print("ds:", ds)
 
 # compute descale_ds
 if IS_FP8:
-    ds_max = tl.max(ds)
-    ds_max = tl.where(ds_max <= 1e-9, 1e-9, ds_max)
-    scale_ds = FP8_MAX / ds_max
-    descale_ds = ds_max / FP8_MAX
+    ds_amax = tl.max(tl.abs(ds))  # NOTE: ds can be negative, so a plain max may return a negative value and break the scaling
+    ds_amax = tl.where(ds_amax <= 1e-9, 1e-9, ds_amax)
+    scale_ds = FP8_MAX / ds_amax
+    descale_ds = ds_amax / FP8_MAX
 else:
     scale_ds, descale_ds = 1.0, 1.0
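Why the abs matters for ds in particular: p is a softmax output and therefore nonnegative, but ds can be all-negative within a block. With an illustrative FP8_MAX = 448 (e4m3) and a tile ds = [-3.0, -1.5, -0.25], the old code computes tl.max(ds) = -0.25 and hence scale_ds = 448 / -0.25 = -1792, a negative scale that flips the sign of every element before the fp8 cast. With the fix, tl.max(tl.abs(ds)) = 3.0 gives scale_ds = 448 / 3.0 ≈ 149.3 and descale_ds = 3.0 / 448, mapping the tile onto [-448, 448] as intended.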
8 changes: 4 additions & 4 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -180,10 +180,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri

 if IS_FP8:
     # compute descale_p
-    p_max = tl.max(p)
-    p_max = tl.where(p_max <= 1e-9, 1e-9, p_max)
-    scale_p = FP8_MAX / p_max
-    descale_p = p_max / FP8_MAX
+    p_amax = tl.max(tl.abs(p))
+    p_amax = tl.where(p_amax <= 1e-9, 1e-9, p_amax)
+    scale_p = FP8_MAX / p_amax
+    descale_p = p_amax / FP8_MAX
 
     # NOTE: put p into the fp8 range and multiply by v, which the user has already put into the fp8 range
     acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v)
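For context, the forward update relies on the scales cancelling: v arrives already scaled into the fp8 range by the caller (with descale_v = 1/scale_v), so multiplying the fp8 dot product by descale_p * descale_v recovers the unscaled accumulator. A quick fp32 sanity check of that algebra (hypothetical shapes; fp8 rounding ignored):

import torch

FP8_MAX = 448.0  # illustrative e4m3 max
p = torch.rand(4, 8)        # softmax tile, entries in [0, 1]
v_real = torch.randn(8, 16)

scale_p = FP8_MAX / p.abs().max().clamp(min=1e-9)
scale_v = FP8_MAX / v_real.abs().max().clamp(min=1e-9)
descale_p, descale_v = 1.0 / scale_p, 1.0 / scale_v

v = v_real * scale_v  # what the caller hands the kernel
acc = (p * scale_p) @ v * descale_p * descale_v
assert torch.allclose(acc, p @ v_real, atol=1e-4)  # scales cancel exactly in fp32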
69 changes: 34 additions & 35 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -739,41 +739,41 @@ def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
 @pytest.mark.parametrize(
     "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
     [
-        # (1, 1, 1, 1, 1, 1),
+        (1, 1, 1, 1, 1, 1),
         (1, 1, 1, 2, 2, 16),
-        # (1, 1, 1, 2, 4, 16),
-        # (1, 2, 2, 2, 4, 16),
-        # (1, 4, 1, 2, 4, 16),
-        # (1, 4, 2, 2, 4, 16),
-        # (1, 1, 1, 4, 2, 16),
-        # (1, 1, 1, 4, 4, 16),
-        # (1, 2, 2, 4, 4, 16),
-        # (2, 1, 1, 4, 4, 16),
-        # (2, 2, 2, 4, 4, 16),
-        # (1, 1, 1, 128, 64, 16),
-        # (2, 2, 2, 2, 128, 1),
-        # (2, 3, 3, 2, 128, 16),
-        # (3, 2, 2, 256, 512, 16),
-        # (3, 3, 3, 128, 128, 64),
-        # (2, 4, 4, 1024, 1024, 64),
-        # (4, 6, 6, 108, 256, 224),
-        # (4, 8, 8, 2048, 2048, 128),
-        # (4, 16, 16, 4096, 4096, 64),
-        # (2, 4, 4, 8192, 8192, 32),
-        # # fa configs
-        # (4, 6, 1, 113, 203, 256),
-        # (4, 6, 1, 128, 217, 256),
-        # (4, 6, 2, 113, 211, 128),
-        # (4, 6, 2, 108, 256, 128),
-        # (4, 6, 1, 256, 512, 64),
-        # (4, 6, 1, 512, 256, 64),
-        # (4, 6, 2, 1024, 1024, 32),
-        # (4, 6, 2, 1023, 1024, 32),
-        # (4, 6, 6, 1024, 1023, 32),
-        # (4, 6, 6, 2048, 2048, 32),
+        (1, 1, 1, 2, 4, 16),
+        (1, 2, 2, 2, 4, 16),
+        (1, 4, 1, 2, 4, 16),
+        (1, 4, 2, 2, 4, 16),
+        (1, 1, 1, 4, 2, 16),
+        (1, 1, 1, 4, 4, 16),
+        (1, 2, 2, 4, 4, 16),
+        (2, 1, 1, 4, 4, 16),
+        (2, 2, 2, 4, 4, 16),
+        (1, 1, 1, 128, 64, 16),
+        (2, 2, 2, 2, 128, 1),
+        (2, 3, 3, 2, 128, 16),
+        (3, 2, 2, 256, 512, 16),
+        (3, 3, 3, 128, 128, 64),
+        (2, 4, 4, 1024, 1024, 64),
+        (4, 6, 6, 108, 256, 224),
+        (4, 8, 8, 2048, 2048, 128),
+        (4, 16, 16, 4096, 4096, 64),
+        (2, 4, 4, 8192, 8192, 32),
+        # fa configs
+        (4, 6, 1, 113, 203, 256),
+        (4, 6, 1, 128, 217, 256),
+        (4, 6, 2, 113, 211, 128),
+        (4, 6, 2, 108, 256, 128),
+        (4, 6, 1, 256, 512, 64),
+        (4, 6, 1, 512, 256, 64),
+        (4, 6, 2, 1024, 1024, 32),
+        (4, 6, 2, 1023, 1024, 32),
+        (4, 6, 6, 1024, 1023, 32),
+        (4, 6, 6, 2048, 2048, 32),
     ],
 )
-@pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('causal', [False, True])
 @pytest.mark.parametrize('dropout_p', [0.0])
 @pytest.mark.parametrize('DEBUG_INPUT', [False])
 @pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device")
@@ -939,7 +939,7 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
     ],
 )
 @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('dropout_p', [0.0, 0.25])
+@pytest.mark.parametrize('dropout_p', [0.0])
 @pytest.mark.parametrize('DEBUG_INPUT', [False])
 @pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device")
 def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT):
@@ -1017,7 +1017,6 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
 descale_q = q_maxes / type_max
 descale_k = k_maxes / type_max
 descale_v = v_maxes / type_max
-descale_p = torch.full_like(descale_q, 1.0 / type_max, dtype=torch.float32, device=q.device)
 descale_do = do_maxes / type_max
 
 # scale tensors to fp8 range
@@ -1068,7 +1067,7 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
     descale_q=descale_q,
     descale_k=descale_k,
     descale_v=descale_v,
-    descale_p=descale_p,
+    descale_p=None,
     descale_do=descale_do
 )
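Note the corresponding API change: with scale_p/descale_p now derived per block inside the kernels from the amax, the host no longer precomputes a global descale_p, and the test simply passes descale_p=None.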
