From 5d5fe819b8d848b1cddcb4c10eba453356c901f4 Mon Sep 17 00:00:00 2001
From: hXl3s
Date: Tue, 20 Aug 2024 19:01:37 +0200
Subject: [PATCH] feat(pytorch): Allow TransformerLayer and MultiheadAttention
 to accept sequence length parameters (#1066)

* Added the ability to pass sequence length parameters to the transformer and MHA layers

Signed-off-by: Lukasz Pierscieniewski

* Documentation for the new parameters

Signed-off-by: Lukasz Pierscieniewski

* Add tests for the THD layout, and an assert for the THD layout with KV-cache

Signed-off-by: Lukasz Pierscieniewski

* Fixed tests

Signed-off-by: Lukasz Pierscieniewski

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Move the THD logic into the shape calculation, add the missing Optional in params

Signed-off-by: Lukasz Pierscieniewski

* Skip the THD test on GPUs older than Ampere

Signed-off-by: Przemek Tredak

---------

Signed-off-by: Lukasz Pierscieniewski
Signed-off-by: Przemek Tredak
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani
Co-authored-by: Przemek Tredak
---
 tests/pytorch/test_numerics.py            | 47 ++++++++++++++++++-
 transformer_engine/pytorch/attention.py   | 44 +++++++++++++----
 .../pytorch/module/layernorm_mlp.py       |  3 +-
 transformer_engine/pytorch/transformer.py | 20 ++++++++
 4 files changed, 102 insertions(+), 12 deletions(-)

diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index a219f24674..a2023f539a 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -34,11 +34,13 @@
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
+from transformer_engine.pytorch.utils import get_device_compute_capability
 import transformer_engine_torch as tex

 # Only run FP8 tests on H100.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+sm_80plus = get_device_compute_capability() >= (8, 0)

 seed = 1234
 torch.manual_seed(seed)
@@ -1548,8 +1550,29 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         attn_input_format="bshd",
     )

-    for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()):
-        assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical"
+    torch.manual_seed(0)
+    block_thd = TransformerLayer(
+        config.hidden_size,
+        4 * config.hidden_size,
+        config.num_attention_heads,
+        layernorm_epsilon=config.eps,
+        init_method=init_method,
+        output_layer_init_method=output_layer_init_method,
+        hidden_dropout=0,
+        attention_dropout=0,
+        kv_channels=config.embed,
+        params_dtype=dtype,
+        apply_residual_connection_post_layernorm=False,
+        output_layernorm=False,
+        device="cuda",
+        attn_input_format="thd",
+        self_attn_mask_type="padding_causal",
+    )
+
+    for (n1, p1), (n2, p2), (n3, p3) in zip(
+        block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters()
+    ):
+        assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

     x_sbhd = torch.randn(
         (config.seq_len, bs, config.hidden_size),
@@ -1559,6 +1582,8 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
     )
     x_bshd = x_sbhd.transpose(0, 1).contiguous()
+    x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
+    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len

     # To make sure forward is also identical (just in case some module decides
     # to act fancy)
@@ -1576,6 +1601,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         y_sbhd.transpose(0, 1).contiguous(),
     )
+    # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
+    if dtype != torch.float32 and sm_80plus:
+        # To make sure forward is also identical (just in case some module decides
+        # to act fancy)
+        torch.manual_seed(0)
+        y_thd = block_thd(
+            x_thd,
+            cu_seqlens_q=x_thd_cumsum,
+            cu_seqlens_kv=x_thd_cumsum,
+            max_seqlen_q=config.seq_len,
+            max_seqlen_kv=config.seq_len,
+        )
+
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+        )
+


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
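Note on the cu_seqlens convention used in the test above: because every sequence in the test has the same length, the cumulative lengths reduce to torch.arange(bs + 1) * seq_len. For variable-length inputs the same [batch_size + 1], torch.int32 convention applies. A minimal sketch; the names and sizes below (seqs, lengths, cu_seqlens, hidden_size) are illustrative and not taken from this patch:

    import torch

    hidden_size = 64
    # Two sequences of different lengths, each [seq_len_i, hidden_size].
    seqs = [
        torch.randn(5, hidden_size, device="cuda"),
        torch.randn(3, hidden_size, device="cuda"),
    ]

    # THD input: all tokens concatenated, shape [total_tokens, hidden_size].
    x_thd = torch.cat(seqs, dim=0)

    # cu_seqlens: cumulative sequence lengths with a leading zero,
    # shape [batch_size + 1], dtype torch.int32.
    lengths = torch.tensor([s.size(0) for s in seqs], device="cuda")
    cu_seqlens = torch.zeros(len(seqs) + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)

    # The maximum sequence length can be passed explicitly or left as None.
    max_seqlen = int(lengths.max())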
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 904dbbde01..71bc15fdad 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -7048,6 +7048,10 @@ def forward(
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         """
@@ -7113,6 +7117,18 @@ def forward(
             ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
             It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
             to the attention score of query i and key j.
+        cu_seqlens_q: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
+            with shape [batch_size + 1] and dtype torch.int32.
+        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
+            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+        max_seqlen_q: Optional[int], default = `None`
+            Maximum sequence length in `query_layer`.
+            Calculated from `cu_seqlens_q` if not provided.
+        max_seqlen_kv: Optional[int], default = `None`
+            Maximum sequence length in `key_layer` and `value_layer`.
+            Calculated from `cu_seqlens_kv` if not provided.
         fast_zero_fill: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
         """
@@ -7139,6 +7155,9 @@ def forward(
         # =================================================

         if inference_params and self.layer_number is not None:
+            assert (
+                self.qkv_format != "thd"
+            ), "qkv_format == thd is not supported for inference with KV-cache!"
             if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_length
                 inf_max_batch_size = inference_params.max_batch_size
@@ -7221,13 +7240,18 @@ def forward(
                     dim=split_dim,
                 )

-            # query: -> [sq, b, np, hn]
-            # key, value: -> [sq, b, ng, hn]
-            query_layer, key_layer, value_layer = (
-                x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
-                for x in (query_layer, key_layer, value_layer)
-            )
-
+            if self.qkv_format == "thd":
+                query_layer, key_layer, value_layer = (
+                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
+                    for x in (query_layer, key_layer, value_layer)
+                )
+            else:
+                # query: -> [sq, b, np, hn]
+                # key, value: -> [sq, b, ng, hn]
+                query_layer, key_layer, value_layer = (
+                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
+                    for x in (query_layer, key_layer, value_layer)
+                )
         elif self.attention_type == "cross":
             # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
             mixed_kv_layer = self.key_value(
@@ -7341,8 +7365,10 @@ def forward(
             key_layer,
             value_layer,
             qkv_format=self.qkv_format,
-            cu_seqlens_q=None,
-            cu_seqlens_kv=None,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
             attention_mask=attention_mask,
             attn_mask_type=attn_mask_type,
             window_size=window_size,
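The docstring above states that max_seqlen_q and max_seqlen_kv are calculated from the corresponding cu_seqlens tensors when left as None. Given the documented definition of cu_seqlens, that calculation amounts to taking the largest difference of consecutive entries. A sketch only; the helper name below is hypothetical and the actual internal implementation is not shown in this hunk:

    import torch

    def _max_seqlen_from_cu_seqlens(cu_seqlens: torch.Tensor) -> int:
        # Per-sequence lengths are the differences of consecutive cumulative sums.
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        return int(seqlens.max().item())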
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index be6df21322..dc9bef645f 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -13,6 +13,7 @@
 from .base import (
     get_workspace,
+    _ub_communicators,
     get_ub,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
@@ -1297,7 +1298,7 @@ def __init__(
         self.gemm_gelu_fusion = (
             bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
             and self.activation == "gelu"
-            and not get_ub("fc1_fprop").is_atomic_gemm()
+            and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
         )

         if tp_group is None:
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index f026da23ef..4cbee3d628 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -529,6 +529,10 @@ def forward(
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
     ) -> torch.Tensor:
         """
@@ -604,6 +608,18 @@ def forward(
             ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
             It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
             to the attention score of query i and key j.
+        cu_seqlens_q: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
+            with shape [batch_size + 1] and dtype torch.int32.
+        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
+            Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
+            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+        max_seqlen_q: Optional[int], default = `None`
+            Maximum sequence length in `query_layer`.
+            Calculated from `cu_seqlens_q` if not provided.
+        max_seqlen_kv: Optional[int], default = `None`
+            Maximum sequence length in `key_layer` and `value_layer`.
+            Calculated from `cu_seqlens_kv` if not provided.
         fast_zero_fill: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
         inference_params: InferenceParams, default = None
@@ -664,6 +680,10 @@ def forward(
             core_attention_bias_type=core_attention_bias_type,
             core_attention_bias=core_attention_bias,
             alibi_slopes=alibi_slopes,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
             fast_zero_fill=fast_zero_fill,
         )
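For reference, a minimal usage sketch mirroring the new test in tests/pytorch/test_numerics.py: a TransformerLayer built with attn_input_format="thd" and called with the new sequence-length arguments. The sizes and variable names are illustrative only, and, per the test's skip condition, this assumes an Ampere-or-newer GPU and a non-FP32 dtype:

    import torch
    from transformer_engine.pytorch import TransformerLayer

    hidden_size, num_heads, seq_len, batch = 256, 4, 128, 2
    dtype = torch.bfloat16

    block_thd = TransformerLayer(
        hidden_size,
        4 * hidden_size,
        num_heads,
        params_dtype=dtype,
        device="cuda",
        attn_input_format="thd",
        self_attn_mask_type="padding_causal",
    )

    # Equal-length sequences packed into [total_tokens, hidden_size].
    x_thd = torch.randn(batch * seq_len, hidden_size, dtype=dtype, device="cuda")
    cu_seqlens = torch.arange(batch + 1, device="cuda", dtype=torch.int32) * seq_len

    y_thd = block_thd(
        x_thd,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_kv=seq_len,
    )  # -> [batch * seq_len, hidden_size]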