[PyTorch] Make FP8 MHA work with RoPE when CP is on (#1297)
* Let FP8 MHA work with RoPE when CP is on

Signed-off-by: Xin Yao <[email protected]>

* Fix and update unit tests

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Xin Yao <[email protected]>
yaox12 authored Nov 4, 2024
1 parent a6a9141 commit c42beef
Showing 4 changed files with 78 additions and 49 deletions.
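For orientation, here is a minimal sketch of the configuration this commit enables: a DelayedScaling recipe with fp8_dpa=True and fp8_mha=True driving DotProductAttention while q/k/v arrive in high precision (as they do after RoPE), plus the FP8 (E5M2) gradient that the backward pass then expects. The module arguments, tensor shapes, and standalone setup are illustrative assumptions (an FP8-capable GPU and fused-attention backend are assumed; context-parallel group setup is omitted), not part of the commit.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch import DotProductAttention
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.common.recipe import DelayedScaling

# FP8 recipe with FP8 attention compute (fp8_dpa) and FP8 module outputs (fp8_mha).
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=True)

# Illustrative module and shapes (bshd layout, 16 heads x 64 channels).
core_attn = DotProductAttention(16, 64, qkv_format="bshd").cuda()
q, k, v = [
    torch.randn(2, 4096, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
]

with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # q/k/v are BF16 here (e.g. produced by RoPE), so the inputs are not FP8,
    # while the output is still returned in FP8 -- the combination this commit
    # makes work under context parallelism.
    out = core_attn(q, k, v)

# With fp8_mha=True the incoming gradient is expected in FP8 (E5M2), mirroring
# the updated test in run_fused_attn_with_cp.py.
dout = Float8Tensor.to_float8(
    torch.randn(2, 4096, 16 * 64, dtype=torch.bfloat16, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E5M2,
)
out.backward(dout)

The diff below wires exactly this combination through the context-parallel attention paths and their tests.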
24 changes: 20 additions & 4 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -11,16 +11,24 @@
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.common.recipe import DelayedScaling

dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}


def run_dpa_with_cp(
dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p"
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
):
"""Test DotProductAttention module with context parallelism"""

# args are passed as strings
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
@@ -72,7 +80,7 @@ def run_dpa_with_cp(
cp_comm_sub_groups.append(sub_group)

if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True)
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)

# instantiate core attn module
core_attn = DotProductAttention(
@@ -201,7 +209,11 @@ def run_dpa_with_cp(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
)
out.backward(dout)
if fp8_mha:
dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2)
out.backward(dout_fp8)
else:
out.backward(dout)

# run core_attn with CP
q_, k_, v_, dout_, *rest = [
@@ -269,7 +281,11 @@ def run_dpa_with_cp(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
)
out_.backward(dout_)
if fp8_mha:
dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2)
out_.backward(dout_fp8_)
else:
out_.backward(dout_)

for x in [out_, q_.grad, k_.grad, v_.grad]:
assert torch.all(~torch.isnan(x))
2 changes: 0 additions & 2 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1356,8 +1356,6 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
config = model_configs_fp8_vs_f16[model]

if _flash_attn_3_is_installed and not is_training:
if RoPE:
pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
6 changes: 5 additions & 1 deletion tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -113,7 +113,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("fp8_mha", [False, True])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
@@ -153,6 +154,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
if dtype != "fp8" and fp8_mha:
pytest.skip("Only fp8 works with fp8_mha=True!")

subprocess.run(
get_bash_arguments(
@@ -162,6 +165,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
fp8_mha=fp8_mha,
),
check=True,
)
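Because the test launches run_fused_attn_with_cp.py in a separate process, every parametrized value, including the new fp8_mha flag, reaches the child as a string, which is why the runner parses it back with fp8_mha == "True". A minimal sketch of that round trip; the internals of get_bash_arguments are not shown in this diff, so the exact argument formatting is an assumption:

# Parent (pytest): the boolean is stringified when building the launch command.
fp8_mha = True                        # from @pytest.mark.parametrize
arg_fragment = f"fp8_mha={fp8_mha}"   # -> "fp8_mha=True"

# Child (run_dpa_with_cp): the value arrives as text and is parsed back to a bool.
received = "True"
fp8_mha = received == "True"

Any other spelling (for example "true" or "1") would silently disable the FP8-MHA path, which is why the runner compares against the exact string "True".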
95 changes: 53 additions & 42 deletions transformer_engine/pytorch/attention.py
@@ -1729,17 +1729,20 @@ def forward(
fused_attn_qkv_dtype = None
fused_attn_backend = None
amax_per_step = None
qkv_dtype = q.dtype
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_backend = FusedAttnBackend["FP8"]
if fp8_meta["recipe"].fp8_mha:
assert (
isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)
), "q/k/v must be Float8Tensors for FP8 MHA!"
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
@@ -1778,7 +1781,7 @@
)
if not fp8:
q_f16 = q
elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_f16 = q
q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)

@@ -1880,11 +1883,7 @@
batch_p2p_comm,
)

if (
not fp8
or fp8_meta["recipe"].fp8_mha
or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
):
if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
kv_inputs[i % 2] = p2p_comm_buffers[i]
else:
# KV exchange is in BF16/FP16, cast received KV in each step
@@ -2436,18 +2435,18 @@ def forward(
fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]

out_fp8 = None
out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
out_f16 = out.to(qkv_dtype)
if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)

if fp8 and fp8_meta["recipe"].fp8_mha:
if fp8 and is_output_fp8:
out_ret = Float8Tensor(
data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q_fp8.dtype,
dtype=qkv_dtype,
)
else:
out_ret = out_f16
@@ -2456,7 +2455,7 @@ def forward(
q_save, kv_save, out_save = q, kv, out_fp8
fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
elif fp8 and fp8_meta["recipe"].fp8_mha:
elif fp8 and is_input_fp8:
q_fp8 = Float8Tensor(
data=q,
fp8_meta=fp8_meta,
@@ -2513,6 +2512,8 @@ def forward(
ctx.use_fused_attention = use_fused_attention
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
return out_ret

@staticmethod
@@ -2595,7 +2596,7 @@ def backward(ctx, dout):
dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
dkv_fp8_ = torch.empty_like(dkv_fp8)
if ctx.fp8_meta["recipe"].fp8_mha:
if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
dout = dout._data
@@ -2617,7 +2618,7 @@ def backward(ctx, dout):
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
if ctx.fp8_meta is not None and ctx.is_input_fp8:
q, kv = [x.from_float8(x.dtype) for x in [q, kv]]
if cp_size_a2a == 1:
dout = dout.from_float8(dout_dtype)
@@ -2653,7 +2654,7 @@ def backward(ctx, dout):
ctx.cp_stream,
True,
)
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8:
dout = cast_from_fp8(
dout,
None,
@@ -3260,7 +3261,7 @@ def backward(ctx, dout):
dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
dkv = dkv_

if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
if ctx.fp8 and ctx.is_input_fp8:
dq, dkv = [
cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
for x in [dq, dkv]
@@ -3283,7 +3284,7 @@ def backward(ctx, dout):
elif ctx.qkv_format == "sbhd":
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
if ctx.fp8 and ctx.is_input_fp8:
dq, dk, dv = [
Float8Tensor(
data=x,
@@ -3852,19 +3853,22 @@ def forward(
q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"

qkv_dtype = q.dtype
fused_attn_backend = None
fused_attn_qkv_dtype = None
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_backend = FusedAttnBackend["FP8"]
if fp8_meta["recipe"].fp8_mha:
assert (
isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)
), "q/k/v must be Float8Tensors for FP8 MHA!"
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
@@ -3900,7 +3904,7 @@ def forward(
[q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
)

if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_f16, k_f16, v_f16 = q, k, v
q, k, v = [
cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
@@ -3965,14 +3969,14 @@ def forward(
out = out.view(-1, batch_size, *out.shape[-2:])

if fp8:
if fp8_meta["recipe"].fp8_mha:
if is_output_fp8:
out_fp8 = Float8Tensor(
data=out,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q_fp8.dtype,
dtype=qkv_dtype,
)
out = out_fp8._data
out_ret = out_fp8
@@ -3991,7 +3995,7 @@
if fp8:
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_save, k_save, v_save, out_save = q, k, v, out
elif fp8_meta["recipe"].fp8_mha:
elif is_input_fp8:
q_fp8, k_fp8, v_fp8 = [
Float8Tensor(
data=x,
@@ -4043,6 +4047,8 @@ def forward(
ctx.use_fused_attention = use_fused_attention
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
return out_ret

@staticmethod
@@ -4064,14 +4070,15 @@ def backward(ctx, dout):
fused_attn_backend = None
fused_attn_dqkv_dtype = None
fused_attn_qkv_dtype = None
dout_dtype = dout.dtype
if ctx.fp8:
if ctx.use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_dqkv_dtype = fp8_dtype_backward
fused_attn_backend = FusedAttnBackend["FP8"]
if ctx.fp8_meta["recipe"].fp8_mha:
if ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
dout_fp8 = dout
@@ -4097,7 +4104,7 @@ def backward(ctx, dout):
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
if ctx.fp8_meta is not None and ctx.is_output_fp8:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]]
if ctx.use_fused_attention:
@@ -4194,15 +4201,15 @@ def backward(ctx, dout):
dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]

if ctx.fp8:
if ctx.fp8_meta["recipe"].fp8_mha:
if ctx.is_input_fp8:
dq, dk, dv = [
Float8Tensor(
data=x,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=dout_fp8.dtype,
dtype=dout_dtype,
)
for x in [dq, dk, dv]
]
@@ -4213,7 +4220,7 @@ def backward(ctx, dout):
ctx.fp8_meta["scaling_bwd"],
META_DQKV,
fp8_dtype_backward,
TE_DType[dout_f16.dtype],
TE_DType[dout_dtype],
)
for x in [dq, dk, dv]
]
@@ -5434,11 +5441,12 @@ def convert_to_torch_float8(tensor, dtype):
)
return out

if fp8_meta["recipe"].fp8_mha:
assert all(
isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "q/k/v must be Float8Tensors for FP8 MHA."
# "fp8_mha" decides outputs in fp8, while inputs are inferred from
# the real dtype
assert isinstance(key_layer, query_layer.__class__) and isinstance(
value_layer, query_layer.__class__
), "q, k, and v must have the same type."
if isinstance(query_layer, Float8Tensor):
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
else:
query_layer, key_layer, value_layer = (
@@ -5580,6 +5588,7 @@ def forward(
deterministic,
):
# pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8:
@@ -5970,6 +5979,7 @@ def forward(
deterministic,
):
# pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8:
@@ -6424,6 +6434,7 @@ def forward(
deterministic,
):
# pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8:
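The recurring pattern in the attention.py changes above: the recipe's fp8_mha flag now only decides whether outputs (and incoming gradients) are FP8, while the input representation is detected from the actual tensor type, since RoPE applied upstream can hand the attention module high-precision q/k/v even under an FP8-MHA recipe. A condensed sketch of that decision; the helper name and packaging are hypothetical, the diff inlines this logic in each forward():

from transformer_engine.pytorch.float8_tensor import Float8Tensor

def resolve_fp8_io(q, k, v, fp8_meta):
    """Illustrative only: condenses the is_input_fp8 / is_output_fp8 logic above."""
    # Outputs (and dout in backward) are FP8 exactly when the recipe asks for fp8_mha.
    is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha
    # Inputs must all share one type and count as FP8 only if they really are
    # Float8Tensors -- after RoPE they may arrive as BF16/FP16 even with fp8_mha on.
    assert isinstance(k, q.__class__) and isinstance(v, q.__class__), \
        "q, k, and v must have the same type."
    is_input_fp8 = isinstance(q, Float8Tensor)
    return is_input_fp8, is_output_fp8

Both flags are stashed on ctx in forward so that backward applies the same distinction to dout and to the returned dq/dk/dv, as the hunks above show.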
