clean up tests a bit more
micmelesse committed Feb 4, 2025
1 parent 38b1b46 · commit 7d0277c
Showing 1 changed file with 24 additions and 31 deletions.
55 changes: 24 additions & 31 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -824,17 +824,17 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
     v_max = torch.maximum(v.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1)
     do_max = torch.maximum(do.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1)
 
-    # compute descale values
-    descale_q = q_max / type_max
-    descale_k = k_max / type_max
-    descale_v = v_max / type_max
-    descale_do = do_max / type_max
+    # compute scaling and descaling factors
+    scale_q, descale_q = type_max/ q_max, q_max / type_max
+    scale_k, descale_k = type_max/ k_max, k_max / type_max
+    scale_v, descale_v = type_max/ v_max, v_max / type_max
+    scale_do, descale_do = type_max/ do_max, do_max / type_max
 
     # scale values to fp8 range
-    q_fp8 = (q.clone() * type_max/ q_max).to(torch.float8_e4m3fnuz).requires_grad_()
-    k_fp8 = (k.clone() * type_max/ k_max).to(torch.float8_e4m3fnuz).requires_grad_()
-    v_fp8 = (v.clone() * type_max/ v_max).to(torch.float8_e4m3fnuz).requires_grad_()
-    do_fp8 = (do.clone() * type_max/ do_max).to(torch.float8_e4m3fnuz)
+    q_fp8 = (q.clone() * scale_q).to(torch.float8_e4m3fnuz).requires_grad_()
+    k_fp8 = (k.clone() * scale_k).to(torch.float8_e4m3fnuz).requires_grad_()
+    v_fp8 = (v.clone() * scale_v).to(torch.float8_e4m3fnuz).requires_grad_()
+    do_fp8 = (do.clone() * scale_do).to(torch.float8_e4m3fnuz)
 
     # fp8 forward pass
     out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_func(
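(Note on the pattern above, not part of the commit: scale = type_max / amax maps each head's values into the representable FP8 range, and descale = amax / type_max is the inverse factor used to recover the original magnitudes. A minimal round-trip sketch with made-up shapes, an assumed [batch, seqlen, heads, dim] layout, and a PyTorch build that has float8 dtypes:)

import torch

# illustrative round trip of the per-(batch, head) FP8 scaling; shapes are made up
fp8_dtype = torch.float8_e4m3fnuz
type_max = torch.finfo(fp8_dtype).max

q = torch.randn(2, 128, 4, 64)  # assumed [batch, seqlen, heads, dim] layout
q_max = torch.maximum(q.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1)

scale_q, descale_q = type_max / q_max, q_max / type_max
q_fp8 = (q * scale_q).to(fp8_dtype)           # quantize into the FP8 range
q_back = q_fp8.to(torch.float32) * descale_q  # dequantize: q_back ≈ q up to FP8 rounding
print((q - q_back).abs().max())               # small residual left by FP8 quantization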
@@ -985,14 +985,13 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
     # --- FP8 ---
     # ----------------------------------------------------------------
     type_max = torch.finfo(torch.float8_e4m3fnuz).max
-    batch = len(metadata.cu_seqlens_q) - 1
 
     # get maxes
     q_maxes = []
     k_maxes = []
     v_maxes = []
     do_maxes = []
-    for i in range(batch):
+    for i in range(Z):
         q_start = metadata.cu_seqlens_q[i]
         q_end = metadata.cu_seqlens_q[i + 1]
         k_start = metadata.cu_seqlens_k[i]
@@ -1013,40 +1012,34 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
     v_maxes = torch.stack(v_maxes)
     do_maxes = torch.stack(do_maxes)
 
-    # compute descale values
-    descale_q = q_maxes / type_max
-    descale_k = k_maxes / type_max
-    descale_v = v_maxes / type_max
-    descale_do = do_maxes / type_max
+    # compute scaling and descaling factors
+    scale_q, descale_q = type_max/ q_maxes, q_maxes / type_max
+    scale_k, descale_k = type_max/ k_maxes, k_maxes / type_max
+    scale_v, descale_v = type_max/ v_maxes, v_maxes / type_max
+    scale_do, descale_do = type_max/ do_maxes, do_maxes / type_max
 
     # scale tensors to fp8 range
-    q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fnuz)
-    k_fp8 = torch.empty_like(k, dtype=torch.float8_e4m3fnuz)
-    v_fp8 = torch.empty_like(v, dtype=torch.float8_e4m3fnuz)
-    do_fp8 = torch.empty_like(do, dtype=torch.float8_e4m3fnuz)
-    for i in range(batch):
+    q_fp8 = torch.zeros_like(q, dtype=torch.float8_e4m3fnuz)
+    k_fp8 = torch.zeros_like(k, dtype=torch.float8_e4m3fnuz)
+    v_fp8 = torch.zeros_like(v, dtype=torch.float8_e4m3fnuz)
+    do_fp8 = torch.zeros_like(do, dtype=torch.float8_e4m3fnuz)
+    for i in range(Z):
         q_start = metadata.cu_seqlens_q[i]
         q_end = metadata.cu_seqlens_q[i + 1]
         k_start = metadata.cu_seqlens_k[i]
         k_end = metadata.cu_seqlens_k[i + 1]
 
-        # shape [heads_q, 1], broadcast to [1, heads_q, 1]
-        q_scale = (type_max / q_maxes[i]).unsqueeze(0) # => [1, HQ, 1]
-        k_scale = (type_max / k_maxes[i]).unsqueeze(0) # => [1, HK, 1]
-        v_scale = (type_max / v_maxes[i]).unsqueeze(0) # => [1, HK, 1]
-        do_scale = (type_max / do_maxes[i]).unsqueeze(0) # => [1, HQ, 1]
-
         # q, k, v are [L, heads, dim] slices
         q_slice = q[q_start:q_end] # [seq_len_i, HQ, dim]
         k_slice = k[k_start:k_end] # [seq_len_i, HK, dim]
         v_slice = v[k_start:k_end] # [seq_len_i, HK, dim]
         do_slice = do[q_start:q_end] # [seq_len_i, HQ, dim]
 
         # Convert them to FP8
-        q_fp8[q_start:q_end] = (q_slice * q_scale).to(torch.float8_e4m3fnuz)
-        k_fp8[k_start:k_end] = (k_slice * k_scale).to(torch.float8_e4m3fnuz)
-        v_fp8[k_start:k_end] = (v_slice * v_scale).to(torch.float8_e4m3fnuz)
-        do_fp8[q_start:q_end] = (do_slice * do_scale).to(torch.float8_e4m3fnuz)
+        q_fp8[q_start:q_end] = (q_slice * scale_q[i]).to(torch.float8_e4m3fnuz)
+        k_fp8[k_start:k_end] = (k_slice * scale_k[i]).to(torch.float8_e4m3fnuz)
+        v_fp8[k_start:k_end] = (v_slice * scale_v[i]).to(torch.float8_e4m3fnuz)
+        do_fp8[q_start:q_end] = (do_slice * scale_do[i]).to(torch.float8_e4m3fnuz)
 
     # launch kernel in fp8
     out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_func(
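(Note on the varlen path above, not part of the commit: the packed [total_tokens, heads, dim] tensors are sliced per sequence via cu_seqlens, each slice is quantized with its own per-head scale, and descale_* keeps the inverse for the kernel. A minimal sketch with hypothetical offsets and shapes, assuming a PyTorch build with float8 dtypes:)

import torch

fp8_dtype = torch.float8_e4m3fnuz
type_max = torch.finfo(fp8_dtype).max

# hypothetical packed varlen layout: 3 sequences concatenated along dim 0
cu_seqlens = [0, 3, 8, 12]
heads, dim = 4, 64
x = torch.randn(cu_seqlens[-1], heads, dim)

# per-sequence, per-head absolute max -> [num_seqs, heads, 1]
x_maxes = torch.stack([
    torch.maximum(x[s:e].abs().amax(dim=(0, 2)), torch.tensor(1e-9)).unsqueeze(-1)
    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])
])
scale_x, descale_x = type_max / x_maxes, x_maxes / type_max

# quantize each sequence slice with its own scale; scale_x[i] broadcasts over tokens and dim
x_fp8 = torch.zeros_like(x, dtype=fp8_dtype)
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
    x_fp8[s:e] = (x[s:e] * scale_x[i]).to(fp8_dtype)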
