diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py
index 51a84d613..f47cff00e 100644
--- a/flash_attn/flash_attn_triton_amd/test.py
+++ b/flash_attn/flash_attn_triton_amd/test.py
@@ -824,17 +824,17 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
     v_max = torch.maximum(v.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1)
     do_max = torch.maximum(do.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1)
 
-    # compute descale values
-    descale_q = q_max / type_max
-    descale_k = k_max / type_max
-    descale_v = v_max / type_max
-    descale_do = do_max / type_max
+    # compute scaling and descaling factors
+    scale_q, descale_q = type_max / q_max, q_max / type_max
+    scale_k, descale_k = type_max / k_max, k_max / type_max
+    scale_v, descale_v = type_max / v_max, v_max / type_max
+    scale_do, descale_do = type_max / do_max, do_max / type_max
 
     # scale values to fp8 range
-    q_fp8 = (q.clone() * type_max/ q_max).to(torch.float8_e4m3fnuz).requires_grad_()
-    k_fp8 = (k.clone() * type_max/ k_max).to(torch.float8_e4m3fnuz).requires_grad_()
-    v_fp8 = (v.clone() * type_max/ v_max).to(torch.float8_e4m3fnuz).requires_grad_()
-    do_fp8 = (do.clone() * type_max/ do_max).to(torch.float8_e4m3fnuz)
+    q_fp8 = (q.clone() * scale_q).to(torch.float8_e4m3fnuz).requires_grad_()
+    k_fp8 = (k.clone() * scale_k).to(torch.float8_e4m3fnuz).requires_grad_()
+    v_fp8 = (v.clone() * scale_v).to(torch.float8_e4m3fnuz).requires_grad_()
+    do_fp8 = (do.clone() * scale_do).to(torch.float8_e4m3fnuz)
 
     # fp8 forward pass
     out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_func(
@@ -985,14 +985,13 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
     # --- FP8 ---
     # ----------------------------------------------------------------
     type_max = torch.finfo(torch.float8_e4m3fnuz).max
-    batch = len(metadata.cu_seqlens_q) - 1
 
     # get maxes
    q_maxes = []
     k_maxes = []
     v_maxes = []
     do_maxes = []
-    for i in range(batch):
+    for i in range(Z):
         q_start = metadata.cu_seqlens_q[i]
         q_end = metadata.cu_seqlens_q[i + 1]
         k_start = metadata.cu_seqlens_k[i]
@@ -1013,29 +1012,23 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
     v_maxes = torch.stack(v_maxes)
     do_maxes = torch.stack(do_maxes)
 
-    # compute descale values
-    descale_q = q_maxes / type_max
-    descale_k = k_maxes / type_max
-    descale_v = v_maxes / type_max
-    descale_do = do_maxes / type_max
+    # compute scaling and descaling factors
+    scale_q, descale_q = type_max / q_maxes, q_maxes / type_max
+    scale_k, descale_k = type_max / k_maxes, k_maxes / type_max
+    scale_v, descale_v = type_max / v_maxes, v_maxes / type_max
+    scale_do, descale_do = type_max / do_maxes, do_maxes / type_max
 
     # scale tensors to fp8 range
-    q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fnuz)
-    k_fp8 = torch.empty_like(k, dtype=torch.float8_e4m3fnuz)
-    v_fp8 = torch.empty_like(v, dtype=torch.float8_e4m3fnuz)
-    do_fp8 = torch.empty_like(do, dtype=torch.float8_e4m3fnuz)
-    for i in range(batch):
+    q_fp8 = torch.zeros_like(q, dtype=torch.float8_e4m3fnuz)
+    k_fp8 = torch.zeros_like(k, dtype=torch.float8_e4m3fnuz)
+    v_fp8 = torch.zeros_like(v, dtype=torch.float8_e4m3fnuz)
+    do_fp8 = torch.zeros_like(do, dtype=torch.float8_e4m3fnuz)
+    for i in range(Z):
         q_start = metadata.cu_seqlens_q[i]
         q_end = metadata.cu_seqlens_q[i + 1]
         k_start = metadata.cu_seqlens_k[i]
         k_end = metadata.cu_seqlens_k[i + 1]
 
-        # shape [heads_q, 1], broadcast to [1, heads_q, 1]
-        q_scale = (type_max / q_maxes[i]).unsqueeze(0)  # => [1, HQ, 1]
-        k_scale = (type_max / k_maxes[i]).unsqueeze(0)  # => [1, HK, 1]
-        v_scale = (type_max / v_maxes[i]).unsqueeze(0)  # => [1, HK, 1]
-        do_scale = (type_max / do_maxes[i]).unsqueeze(0)  # => [1, HQ, 1]
-
         # q, k, v are [L, heads, dim] slices
         q_slice = q[q_start:q_end]  # [seq_len_i, HQ, dim]
         k_slice = k[k_start:k_end]  # [seq_len_i, HK, dim]
@@ -1043,10 +1036,10 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
         do_slice = do[q_start:q_end]  # [seq_len_i, HQ, dim]
 
         # Convert them to FP8
-        q_fp8[q_start:q_end] = (q_slice * q_scale).to(torch.float8_e4m3fnuz)
-        k_fp8[k_start:k_end] = (k_slice * k_scale).to(torch.float8_e4m3fnuz)
-        v_fp8[k_start:k_end] = (v_slice * v_scale).to(torch.float8_e4m3fnuz)
-        do_fp8[q_start:q_end] = (do_slice * do_scale).to(torch.float8_e4m3fnuz)
+        q_fp8[q_start:q_end] = (q_slice * scale_q[i]).to(torch.float8_e4m3fnuz)
+        k_fp8[k_start:k_end] = (k_slice * scale_k[i]).to(torch.float8_e4m3fnuz)
+        v_fp8[k_start:k_end] = (v_slice * scale_v[i]).to(torch.float8_e4m3fnuz)
+        do_fp8[q_start:q_end] = (do_slice * scale_do[i]).to(torch.float8_e4m3fnuz)
 
     # launch kernel in fp8
     out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_func(
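
For reviewers following the scaling math: both hunks now keep scale = type_max / amax (applied before the cast to fp8) and descale = amax / type_max (handed to the kernel to undo it) as an explicit pair, and the varlen loop indexes the precomputed scale_q[i] instead of rebuilding a per-iteration q_scale. Below is a minimal standalone sketch of that round trip, assuming a PyTorch build with float8_e4m3fnuz support (e.g. ROCm); fp8_quantize is a hypothetical helper for illustration, not an API touched by this PR.

import torch

FP8 = torch.float8_e4m3fnuz

def fp8_quantize(x: torch.Tensor):
    """Per-head fp8 quantization of an [L, H, D] slice, mirroring the test's
    per-segment scaling; returns the fp8 tensor and its descale factor."""
    type_max = torch.finfo(FP8).max
    # per-head abs-max over seq_len and head_dim, clamped away from zero
    x_max = x.abs().amax(dim=(0, 2)).clamp(min=1e-9).view(1, -1, 1)  # [1, H, 1]
    scale, descale = type_max / x_max, x_max / type_max  # multiplicative inverses
    x_fp8 = (x * scale).to(FP8)  # stretch each head to span the fp8 range
    return x_fp8, descale

q = torch.randn(128, 4, 64)  # [seq_len, heads, head_dim]
q_fp8, descale_q = fp8_quantize(q)
q_back = q_fp8.float() * descale_q  # round trip
print((q - q_back).abs().max())  # small residual, bounded by fp8 rounding error

Note that a [H, 1] factor broadcasts against an [L, H, D] slice on its own, which is why the new code can multiply by scale_q[i] directly and drop the old unsqueeze(0) step.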