From 38b1b46a0745ec20aa60fccedf65145332e09e6b Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 4 Feb 2025 12:45:52 -0800 Subject: [PATCH] use max(abs(()) --- .../flash_attn_triton_amd/bwd_prefill.py | 26 ++++--- .../flash_attn_triton_amd/fwd_prefill.py | 8 +-- flash_attn/flash_attn_triton_amd/test.py | 69 +++++++++---------- 3 files changed, 50 insertions(+), 53 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py index e8462e02a..38f79bca2 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py @@ -246,10 +246,10 @@ def _bwd_kernel_one_col_block( # compute dv if IS_FP8: # compute descale_p - p_max = tl.max(p_drop_scaled) - p_max = tl.where(p_max <= 1e-9, 1e-9, p_max) - scale_p = FP8_MAX / p_max - descale_p = p_max / FP8_MAX + p_amax = tl.max(tl.abs(p_drop_scaled)) + p_amax = tl.where(p_amax <= 1e-9, 1e-9, p_amax) + scale_p = FP8_MAX / p_amax + descale_p = p_amax / FP8_MAX # NOTE: put p into fp8 range and multiple by do which is already in the fp8 range by the user dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do) @@ -267,10 +267,10 @@ def _bwd_kernel_one_col_block( # compute dv if IS_FP8: # compute descale_p - p_max = tl.max(p) - p_max = tl.where(p_max <= 1e-9, 1e-9, p_max) - scale_p = FP8_MAX / p_max - descale_p = p_max / FP8_MAX + p_amax = tl.max(tl.abs(p)) + p_amax = tl.where(p_amax <= 1e-9, 1e-9, p_amax) + scale_p = FP8_MAX / p_amax + descale_p = p_amax / FP8_MAX # NOTE: put p into fp8 range and multiple by do which is already in the fp8 range by the user dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do) @@ -292,15 +292,13 @@ def _bwd_kernel_one_col_block( dscores_scaled = (p * (dp - delta_i[:, None])) ds = dscores_scaled * sm_scale ds = tl.where(p_mask, ds, 0.0) - - print("ds:", ds) # compute descale_ds if IS_FP8: - ds_max = tl.max(ds) - ds_max 
= tl.where(ds_max <= 1e-9, 1e-9, ds_max) - scale_ds = FP8_MAX / ds_max - descale_ds = ds_max / FP8_MAX + ds_amax = tl.max(tl.abs(ds)) # NOTE: ds can be negative, so scale by max(abs(ds)); a negative plain max would break the scaling. + ds_amax = tl.where(ds_amax <= 1e-9, 1e-9, ds_amax) + scale_ds = FP8_MAX / ds_amax + descale_ds = ds_amax / FP8_MAX else: scale_ds, descale_ds = 1.0, 1.0 diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 4380f8f80..5b292bbce 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -180,10 +180,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri if IS_FP8: # compute descale_p - p_max = tl.max(p) - p_max = tl.where(p_max <= 1e-9, 1e-9, p_max) - scale_p = FP8_MAX / p_max - descale_p = p_max / FP8_MAX + p_amax = tl.max(tl.abs(p)) + p_amax = tl.where(p_amax <= 1e-9, 1e-9, p_amax) + scale_p = FP8_MAX / p_amax + descale_p = p_amax / FP8_MAX # NOTE: put p into fp8 range and multiple by do which is already in the fp8 range by the user acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 3f7382d41..51a84d613 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -739,41 +739,41 @@ def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): @pytest.mark.parametrize( "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ - # (1, 1, 1, 1, 1, 1), + (1, 1, 1, 1, 1, 1), (1, 1, 1, 2, 2, 16), - # (1, 1, 1, 2, 4, 16), - # (1, 2, 2, 2, 4, 16), - # (1, 4, 1, 2, 4, 16), - # (1, 4, 2, 2, 4, 16), - # (1, 1, 1, 4, 2, 16), - # (1, 1, 1, 4, 4, 16), - # (1, 2, 2, 4, 4, 16), - # (2, 1, 1, 4, 4, 16), - # (2, 2, 2, 4, 4, 16), - # (1, 1, 1, 128, 64, 16), - # (2, 2, 2, 2, 128, 1), - # (2, 3, 3, 2, 128, 16), - # (3, 2, 2, 256, 512, 16), - # 
(3, 3, 3, 128, 128, 64), - # (2, 4, 4, 1024, 1024, 64), - # (4, 6, 6, 108, 256, 224), - # (4, 8, 8, 2048, 2048, 128), - # (4, 16, 16, 4096, 4096, 64), - # (2, 4, 4, 8192, 8192, 32), - # # fa configs - # (4, 6, 1, 113, 203, 256), - # (4, 6, 1, 128, 217, 256), - # (4, 6, 2, 113, 211, 128), - # (4, 6, 2, 108, 256, 128), - # (4, 6, 1, 256, 512, 64), - # (4, 6, 1, 512, 256, 64), - # (4, 6, 2, 1024, 1024, 32), - # (4, 6, 2, 1023, 1024, 32), - # (4, 6, 6, 1024, 1023, 32), - # (4, 6, 6, 2048, 2048, 32), + (1, 1, 1, 2, 4, 16), + (1, 2, 2, 2, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (2, 1, 1, 4, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 128, 64, 16), + (2, 2, 2, 2, 128, 1), + (2, 3, 3, 2, 128, 16), + (3, 2, 2, 256, 512, 16), + (3, 3, 3, 128, 128, 64), + (2, 4, 4, 1024, 1024, 64), + (4, 6, 6, 108, 256, 224), + (4, 8, 8, 2048, 2048, 128), + (4, 16, 16, 4096, 4096, 64), + (2, 4, 4, 8192, 8192, 32), + # fa configs + (4, 6, 1, 113, 203, 256), + (4, 6, 1, 128, 217, 256), + (4, 6, 2, 113, 211, 128), + (4, 6, 2, 108, 256, 128), + (4, 6, 1, 256, 512, 64), + (4, 6, 1, 512, 256, 64), + (4, 6, 2, 1024, 1024, 32), + (4, 6, 2, 1023, 1024, 32), + (4, 6, 6, 1024, 1023, 32), + (4, 6, 6, 2048, 2048, 32), ], ) -@pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('dropout_p', [0.0]) @pytest.mark.parametrize('DEBUG_INPUT', [False]) @pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") @@ -939,7 +939,7 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ], ) @pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.25]) +@pytest.mark.parametrize('dropout_p', [0.0]) @pytest.mark.parametrize('DEBUG_INPUT', [False]) @pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") def test_op_prefill_varlen_fp8(Z, HQ, HK, 
N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT): @@ -1017,7 +1017,6 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop descale_q = q_maxes / type_max descale_k = k_maxes / type_max descale_v = v_maxes / type_max - descale_p = torch.full_like(descale_q, 1.0 / type_max, dtype=torch.float32, device=q.device) descale_do = do_maxes / type_max # scale tensors to fp8 range @@ -1068,7 +1067,7 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop descale_q=descale_q, descale_k=descale_k, descale_v=descale_v, - descale_p=descale_p, + descale_p=None, descale_do=descale_do )