diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index b4eeefa70b..4762cccee6 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,8 +17,8 @@ jobs: uses: actions/checkout@v3 - name: 'Install dependencies' run: | - pip install sphinx==5.1.1 sphinx_rtd_theme==1.0.0 nbsphinx==0.8.10 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==2.15.7 - pip install breathe==4.34.0 sphinx-autoapi==2.0.1 + pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 + pip install breathe==4.35.0 sphinx-autoapi==3.3.2 sudo apt-get install -y pandoc graphviz doxygen export GIT_SHA=$(git show-ref --hash HEAD) - name: 'Build docs' diff --git a/.gitignore b/.gitignore index 6890911c14..9b61454e21 100644 --- a/.gitignore +++ b/.gitignore @@ -22,9 +22,7 @@ __pycache__ .hypothesis .devcontainer.json tests/cpp/build/ -docs/_build .ipynb_checkpoints -docs/doxygen *.log CMakeFiles/CMakeSystem.cmake sdist/ @@ -40,3 +38,4 @@ dist/ downloads/ .pytest_cache/ compile_commands.json +.nfs diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 2533f5e5c1..936021bfed 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b +Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45 diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 9152229d2f..e7924d8a21 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -11,7 +11,6 @@ from .utils import ( all_files_in_dir, cuda_archs, - cuda_path, cuda_version, ) @@ -29,9 +28,6 @@ def setup_pytorch_extension( sources = [ csrc_source_files / "common.cu", csrc_source_files / "ts_fp8_op.cpp", - csrc_source_files / "userbuffers" / "ipcsocket.cc", - csrc_source_files / "userbuffers" / "userbuffers.cu", - csrc_source_files / "userbuffers" / "userbuffers-host.cpp", ] + all_files_in_dir(extensions_dir) # Header files @@ -85,19 +81,14 @@ def setup_pytorch_extension( continue # Already handled nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) - # Libraries - library_dirs = [] - libraries = [] - if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))): + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): assert ( os.getenv("MPI_HOME") is not None - ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" - mpi_home = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_home / "include") + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" + mpi_path = Path(os.getenv("MPI_HOME")) + include_dirs.append(mpi_path / "include") cxx_flags.append("-DNVTE_UB_WITH_MPI") nvcc_flags.append("-DNVTE_UB_WITH_MPI") - library_dirs.append(mpi_home / "lib") - libraries.append("mpi") # Construct PyTorch CUDA extension sources = [str(path) for path in sources] @@ -112,6 +103,4 @@ def setup_pytorch_extension( "cxx": cxx_flags, "nvcc": nvcc_flags, }, - libraries=[str(lib) for lib in libraries], - library_dirs=[str(lib_dir) for lib_dir in library_dirs], ) diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000000..409af2d74e --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,3 @@ +_build +doxygen +sphinx_rtd_theme \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index d4bb2cbb9e..800eeea78a 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -16,5 +16,10 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. 
$(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +%: Makefile sphinx_rtd_theme + PYTHONPATH=sphinx_rtd_theme:$(PYTHONPATH) $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +# Patch Sphinx RTD theme 3.0.1 to add version selector in sidebar +sphinx_rtd_theme: + git clone --depth=1 -b 3.0.1 --single-branch https://github.com/readthedocs/sphinx_rtd_theme.git + bash -c "cd sphinx_rtd_theme; git apply ../version_select.patch" diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index b097f14475..ba4e7db352 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -51,3 +51,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_permute .. autoapifunction:: transformer_engine.pytorch.moe_unpermute + +.. autoapifunction:: transformer_engine.pytorch.initialize_ub + +.. autoapifunction:: transformer_engine.pytorch.destroy_ub diff --git a/docs/conf.py b/docs/conf.py index 7a50ce76cf..7d2d4ea7b9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -2,38 +2,30 @@ # # See LICENSE for license information. +import datetime import os -import sys -import sphinx_rtd_theme -from sphinx.ext.autodoc.mock import mock -from sphinx.ext.autodoc import between, ClassDocumenter, AttributeDocumenter -from sphinx.util import inspect -from builtins import str -from enum import Enum -import re +import pathlib import subprocess -from pathlib import Path -from datetime import date - -te_path = os.path.dirname(os.path.realpath(__file__)) +from builtins import str -with open(te_path + "/../build_tools/VERSION.txt", "r") as f: - te_version = f.readline().strip() +# Basic project info +project = "Transformer Engine" +author = "NVIDIA CORPORATION & AFFILIATES" +# Copyright statement release_year = 2022 - -current_year = date.today().year +current_year = datetime.date.today().year if current_year == release_year: copyright_year = release_year else: copyright_year = str(release_year) + "-" + str(current_year) +copyright = f"{copyright_year}, NVIDIA CORPORATION & AFFILIATES. All rights reserved." -project = "Transformer Engine" -copyright = "{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.".format(copyright_year) -author = "NVIDIA CORPORATION & AFFILIATES" +# Transformer Engine root directory +root_path = pathlib.Path(__file__).resolve().parent.parent +# Git hash git_sha = os.getenv("GIT_SHA") - if not git_sha: try: git_sha = ( @@ -44,31 +36,16 @@ ) except: git_sha = "0000000" - git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha -if "dev" in te_version: - version = str(te_version + "-" + git_sha) +# Version +with open(root_path / "build_tools" / "VERSION.txt", "r") as f: + _raw_version = f.readline().strip() +if "dev" in _raw_version: + version = str(_raw_version + "-" + git_sha) else: - version = str(te_version) -release = te_version - -# hack: version is used for html creation, so put the version picker -# link here as well: -option_on = " selected" -option_off = "" -release_opt = option_on -option_nr = 0 -version = ( - version - + """
-Version select: """.format( - option_nr, release_opt - ) -) + version = str(_raw_version) +release = _raw_version # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -92,12 +69,10 @@ pygments_style = "sphinx" - # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] html_static_path = ["_static"] html_show_sphinx = False @@ -106,7 +81,12 @@ "css/nvidia_footer.css", ] -html_theme_options = {"display_version": True, "collapse_navigation": False, "logo_only": False} +html_theme_options = { + "collapse_navigation": False, + "logo_only": False, + "version_selector": False, + "language_selector": False, +} napoleon_custom_sections = [ ("Parallelism parameters", "params_style"), @@ -116,8 +96,8 @@ ("FP8-related parameters", "params_style"), ] -breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")} +breathe_projects = {"TransformerEngine": root_path / "docs" / "doxygen" / "xml"} breathe_default_project = "TransformerEngine" autoapi_generate_api_docs = False -autoapi_dirs = ["../transformer_engine"] +autoapi_dirs = [root_path / "transformer_engine"] diff --git a/docs/version_select.patch b/docs/version_select.patch new file mode 100644 index 0000000000..75f29fff81 --- /dev/null +++ b/docs/version_select.patch @@ -0,0 +1,21 @@ +diff --git a/sphinx_rtd_theme/layout.html b/sphinx_rtd_theme/layout.html +index e6a38b1..579eaec 100644 +--- a/sphinx_rtd_theme/layout.html ++++ b/sphinx_rtd_theme/layout.html +@@ -124,6 +124,16 @@ + {%- endif %} + + ++ {# Show TE version and version selector #} ++
++ {{ version }} ++
++ Version select: ++
++ + {%- if READTHEDOCS or DEBUG %} + {%- if theme_version_selector or theme_language_selector %} +
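The setup.py hunk below (and the build_tools/pytorch.py hunk above) gates the MPI-enabled userbuffers build on the NVTE_UB_WITH_MPI and MPI_HOME environment variables. As a reader's aid, here is a minimal sketch of that flag-parsing pattern in isolation; only NVTE_UB_WITH_MPI and MPI_HOME come from the diff, while the helper name ub_with_mpi_enabled is invented for illustration.

import os
from pathlib import Path

def ub_with_mpi_enabled() -> bool:
    # A string default keeps int() well-defined whether or not the variable is set.
    return bool(int(os.getenv("NVTE_UB_WITH_MPI", "0")))

if ub_with_mpi_enabled():
    mpi_home = os.getenv("MPI_HOME")
    assert mpi_home is not None, "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
    mpi_include = Path(mpi_home) / "include"  # added to the extension's include dirs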
diff --git a/setup.py b/setup.py index 512defa619..3bb2fe6b95 100644 --- a/setup.py +++ b/setup.py @@ -57,13 +57,20 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" + cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" + cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + # Project directory root root_path = Path(__file__).resolve().parent return CMakeExtension( name="transformer_engine", cmake_path=root_path / Path("transformer_engine/common"), - cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())], + cmake_flags=cmake_flags, ) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 6991d83d4c..6b38cfc751 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -3,9 +3,8 @@ # See LICENSE for license information. from contextlib import nullcontext -import functools -import operator from typing import Callable, List, Sequence, Union +import os import jax import jax.numpy as jnp @@ -14,12 +13,17 @@ from jax import jit, value_and_grad from flax import linen as nn -from utils import assert_allclose +from utils import assert_allclose, assert_tree_like_allclose from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu +from transformer_engine.jax.cpp_extensions.transpose import ( + _jax_transpose, + _jax_cast_transpose, +) +from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8 from transformer_engine.jax import cpp_extensions as tex @@ -746,3 +750,102 @@ def ref_func(x, y, gamma, beta, zero_centered_gamma): assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE) if beta is not None: assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE) + + +@pytest.mark.parametrize( + "in_dtype", + [ + pytest.param(jnp.float32, id="input_float32"), + pytest.param(jnp.float16, id="input_float16"), + pytest.param(jnp.bfloat16, id="input_bfloat16"), + ], +) +@pytest.mark.parametrize( + "input_shape, transpose_axis", + [ + pytest.param((16, 16), 1, id="(16, 16)-1"), + pytest.param((256, 128), 1, id="(256, 128)-1"), + pytest.param((128, 512), 1, id="(128, 512)-1"), + pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"), + pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"), + pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"), + ], +) +class TestTranspose: + def test_transpose(self, in_dtype, input_shape, transpose_axis): + key = jax.random.PRNGKey(0) + input_tensor = jax.random.uniform(key, input_shape, in_dtype) + static_axis_boundary = -1 + jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis) + os.environ["NVTE_JAX_WITH_FFI"] = "0" + noffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis) + os.environ["NVTE_JAX_WITH_FFI"] = "1" + ffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis) + assert_allclose(jax_output, noffi_output) + assert_allclose(noffi_output, ffi_output) + + 
@pytest.mark.parametrize( + "out_dtype", + [ + pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"), + pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"), + ], + ) + def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype): + amax = jnp.zeros(1, jnp.float32) + scale = jnp.ones(1, jnp.float32) + scale_inv = jnp.ones(1, jnp.float32) + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, input_shape, in_dtype) + static_axis_boundary = -1 + jax_output = _jax_cast_transpose( + input, scale, amax, out_dtype, static_axis_boundary, transpose_axis + ) + os.environ["NVTE_JAX_WITH_FFI"] = "0" + noffi_output = tex.cast_transpose( + input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis + ) + os.environ["NVTE_JAX_WITH_FFI"] = "1" + ffi_output = tex.cast_transpose( + input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis + ) + assert_tree_like_allclose(jax_output, ffi_output) + assert_tree_like_allclose(noffi_output, ffi_output) + + +@pytest.mark.skipif(not is_fp8_supported, reason=reason) +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((256, 128), id="(256, 128)"), + pytest.param((128, 512, 8), id="(128, 512, 8)"), + ], +) +@pytest.mark.parametrize( + "in_dtype", + [ + pytest.param(jnp.float32, id="input_float32"), + pytest.param(jnp.float16, id="input_float16"), + pytest.param(jnp.bfloat16, id="input_bfloat16"), + ], +) +@pytest.mark.parametrize( + "out_dtype", + [ + pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"), + pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"), + ], +) +def test_quantize(input_shape, in_dtype, out_dtype): + amax = jnp.zeros(1, jnp.float32) + scale = jnp.ones(1, jnp.float32) + scale_inv = jnp.ones(1, jnp.float32) + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, input_shape, in_dtype) + jax_output = _jax_cast_fp8(input, scale, amax, out_dtype) + os.environ["NVTE_JAX_WITH_FFI"] = "0" + noffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype) + os.environ["NVTE_JAX_WITH_FFI"] = "1" + ffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype) + assert_tree_like_allclose(jax_output, ffi_output) + assert_tree_like_allclose(noffi_output, ffi_output) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index c101a89c4c..23a26087d4 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -17,7 +17,13 @@ generate_collectives_count, compare_ops, ) -from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose +from utils import ( + make_causal_mask, + make_self_mask, + assert_tree_like_allclose, + assert_allclose, + print_debug_tensor_stats, +) from transformer_engine.jax import fp8_autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, @@ -31,6 +37,8 @@ inverse_reorder_causal_load_balancing, ) +# We will use the golden reference model from our non distributed attention test fixture. 
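# (Editor's illustration, not part of the change set: the golden reference imported just
# below is assumed to behave like a plain, higher-precision softmax(QK^T * scale)V
# attention. The sketch and its name _sketch_dot_product_attention are hypothetical and
# ignore bias, dropout, and grouped KV heads, which the real helper also handles.)
import jax
import jax.numpy as jnp

def _sketch_dot_product_attention(q, k, v, mask=None, scale_factor=1.0):
    # q: [batch, seq_q, heads, dim]; k, v: [batch, seq_kv, heads, dim]
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k).astype(jnp.float32) * scale_factor
    if mask is not None:
        # Mask convention assumed: nonzero entries mark positions to drop.
        logits = jnp.where(mask, jnp.finfo(jnp.float32).min, logits)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v.astype(jnp.float32))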
+from test_fused_attn import general_dot_product_attention, make_mask DTYPES = [jnp.float16, jnp.bfloat16] @@ -327,18 +335,27 @@ def ref_func(query, kv, mask): ) -class TestDistributedContexParallelSelfAttn: +class TestDistributedContextParallelSelfAttn: def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): batch, seqlen, heads, hidden = shape + kv_shape = (batch, seqlen, heads // kv_groups, hidden) qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) q = random.normal(qkey, shape, dtype=dtype) k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - mask = None - if attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_causal_mask(batch, seqlen) + def gen_valid(bs, max_seqlen, pad_ratio): + pad_len = int(max_seqlen * pad_ratio) + valid_len = max_seqlen - pad_len + tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1) + return tokens, jnp.logical_not(tokens) + + from test_fused_attn import make_mask + + q_idx, _ = gen_valid(batch, seqlen, 0.0) + kv_idx, _ = gen_valid(batch, seqlen, 0.0) + mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type) return q, k, v, mask @@ -382,7 +399,8 @@ def qkv_to_layout(self, q, k, v, qkv_layout): ], ) @pytest.mark.parametrize( - "load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")] + "load_balanced", + [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")], ) def test_contex_parallel_self_attn( self, @@ -400,12 +418,12 @@ def test_contex_parallel_self_attn( attn_bias_type = AttnBiasType.NO_BIAS dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 dp_size, cp_size, tp_size = mesh_shape qkv_format = get_qkv_format(qkv_layout) - _, seqlen, num_head, hidden = data_shape + batch, seqlen, num_head, hidden = data_shape num_kv_heads = num_head // kv_groups + scaling_factor = 1.0 / np.sqrt(num_head) if not is_fused_attn_kernel_available( dtype, @@ -424,54 +442,69 @@ def test_contex_parallel_self_attn( ): pytest.skip(f"No FusedAttn backend found") + if dp_size > 1 and batch % dp_size != 0: + pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}") + # make sure the mesh even divides cp and tp axis if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") def target_func(q, k, v, mask): - return jnp.mean( - fused_attn( - self.qkv_to_layout(q, k, v, qkv_layout), - bias=None, - mask=mask, - seed=None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - context_parallel_causal_load_balanced=load_balanced, - ), + return fused_attn( + self.qkv_to_layout(q, k, v, qkv_layout), + None, # bias + mask, + None, # seed + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training, + context_parallel_causal_load_balanced=load_balanced, + context_parallel_axis="cp", ).astype(dtype) - def ref_func(q, k, v, mask, kv_groups): - q = jnp.squeeze(q) - k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2)) - v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2)) - output = dot_product_attention( + def ref_func(q, k, v, mask): + output = general_dot_product_attention( q, k, v, bias=None, 
mask=mask, - deterministic=is_training, + deterministic=not is_training, + scale_factor=scaling_factor, dropout_rate=dropout_prob, dropout_rng=None, dtype=jnp.float32, ) - return jnp.mean(output).astype(dtype) + return output.astype(dtype) + + def grad_func(func, *args, **kwargs): + # Gradient is small, use a gradient multiplier to amplify the gradient + _, max_seq_len, num_heads, _ = data_shape + gradient_multiplier = max_seq_len * num_heads + if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]: + gradient_multiplier /= 10 + ret_valid = func(*args, **kwargs) + return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) + diff_argnums = (0, 1, 2) + # Single GPU (reference) - ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4]) - ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups) + ref_func_jit = jax.jit( + jax.value_and_grad( + lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums + ) + ) + ref_fwd, ref_grads = ref_func_jit(q, k, v, mask) # Multi GPU (function under test) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): + with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False): qkv_ps = PartitionSpec( mesh_resource.dp_resource, mesh_resource.cp_resource, @@ -499,7 +532,10 @@ def ref_func(q, k, v, mask, kv_groups): mask_ = jax.device_put(mask, device=mask_sharding) target_func_jit = jax.jit( - jax.value_and_grad(target_func, argnums=[0, 1, 2]), + jax.value_and_grad( + lambda q, k, v, mask: grad_func(target_func, q, k, v, mask), + argnums=diff_argnums, + ), in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), ) @@ -510,37 +546,25 @@ def ref_func(q, k, v, mask, kv_groups): target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) - def _print_diffs(target, ref): - print("min: ", jnp.min(target), jnp.min(ref)) - print("max: ", jnp.max(target), jnp.max(ref)) - print("mean: ", jnp.mean(target), jnp.mean(ref)) - print("median: ", jnp.median(target), jnp.median(ref)) - print("std: ", jnp.std(target), jnp.std(ref)) - print("var: ", jnp.var(target), jnp.var(ref)) - print("max diff: ", jnp.max(jnp.abs(target - ref))) - has_diffs = False - try: - assert_allclose(target_fwd, ref_fwd, dtype=dtype) - except AssertionError as e: - has_diffs = True - print(f"target_fwd v. ref_fwd") - _print_diffs(target_fwd, ref_fwd) + print_debug_tensor_stats("target", target_fwd) + print_debug_tensor_stats("ref", ref_fwd) + print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd)) + assert_allclose(target_fwd, ref_fwd, dtype=dtype) for i in range(len(target_grads)): if ref_grads[i] is None or target_grads[i] is None: # expect both none if one is assert target_grads[i] is None and ref_grads[i] is None else: - try: - assert_allclose(target_grads[i], ref_grads[i]) - except AssertionError as e: - has_diffs = True - print(f"target_grads[{i}] v. 
ref_grads[{i}]") - _print_diffs(target_grads[i], ref_grads[i]) - - assert has_diffs == False, "has_diffs != False" + print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i]) + print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i]) + print_debug_tensor_stats( + f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i]) + ) + + assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) class TestReorderCausalLoadBalancing: diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 3245bca676..55c09b4562 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -67,7 +67,7 @@ def enable_fused_attn(): _KEY_OF_TRANSPOSE_BS: True, _KEY_OF_NUM_HEADS: 8, _KEY_OF_HIDDEN_DROPOUT: 0, - _KEY_OF_ATTENTION_DROPOUT: 0, + _KEY_OF_ATTENTION_DROPOUT: 0.0, _KEY_OF_INTERMEDIATE_DROPOUT: 0, _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal", _KEY_OF_LAYERNORM_TYPE: "layernorm", diff --git a/tests/jax/utils.py b/tests/jax/utils.py index cefda1a2f5..78a6225e1f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -7,6 +7,7 @@ import math import operator from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional +import os import jax import jax.numpy as jnp @@ -30,6 +31,9 @@ ] Initializer = Callable[[PRNGKey, Shape, DType], Array] +# Enables verbose printing of tensor numerics for debug. +NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0))) + def is_devices_enough(required): """ @@ -1466,3 +1470,23 @@ def sync_params_values(dst, src, transformations, sep="/"): synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values) return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst) + + +@functools.partial(jax.jit, static_argnums=[0, 2]) +def print_debug_tensor_stats(prefix, tensor, hist=False): + if NVTE_DEBUG_NUMERICS: + args = [ + jnp.mean(tensor), + jnp.min(tensor), + jnp.max(tensor), + jnp.cumprod(jnp.array(tensor.shape))[-1] if len(tensor.shape) >= 1 else 1, + jnp.count_nonzero(tensor), + ] + fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}" + + if hist: + h = jnp.histogram(tensor.astype(jnp.float32), bins=10) + args += [h[0], h[1]] + fmt = fmt + "\n {}\n {}" + + jax.debug.print(fmt, *args) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 5ba70ccbdd..b00b8cc042 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -22,6 +22,7 @@ from transformer_engine.common.recipe import Format from transformer_engine.pytorch.fp8 import _default_sf_compute +warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) @@ -32,8 +33,8 @@ } nvte_comm_types = { - "rs": 0, - "ag": 1, + "rs": tex.CommOverlapType.RS, + "ag": tex.CommOverlapType.AG, } @@ -75,7 +76,7 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--comm-type", type=partial(_mapped_argtype, typemap=nvte_comm_types), - default=0, + default=tex.CommOverlapType.AG, help="Comm type to overlap.", ) parser.add_argument( @@ -156,12 +157,10 @@ def _parse_args(argv=None, namespace=None): if opts.fp8: warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.") opts.fp8 = False - elif opts.comm_type == 1: + elif opts.comm_type == tex.CommOverlapType.AG: if opts.atomic: setattr(opts, "atomic_rs_p2p", opts.p2p) - if not opts.p2p: - warnings.warn("All-gather overlap is only 
supported with point-2-point comms.") - opts.p2p = True + opts.p2p = True if opts.atomic: if not te.fp8.check_fp8_support(): @@ -283,35 +282,35 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None if WORLD_RANK == 0: print("\n", end="", flush=True) - ub_callbacks = ( - tex.UbufBootstrapCallbacks() + helper = ( + tex.CommOverlapHelper() if tex.ubuf_built_with_mpi() - else tex.UbufBootstrapCallbacks(bootstrap_pg, bootstrap_pg) + else tex.CommOverlapHelper(bootstrap_pg) ) - if opts.comm_type == 0: + if opts.comm_type == tex.CommOverlapType.RS: if opts.bulk_overlap: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_RS + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS elif opts.p2p: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P ) else: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + tex.CommOverlapAlgo.ATOMIC_GEMM_RS if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ) - elif opts.comm_type == 1: + elif opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG else: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ) else: raise TypeError("Invalid comm+GEMM overlap type!") @@ -322,95 +321,55 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None hidden_size = opts.num_heads * opts.head_dim inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) - ubuf_dtype = torch.bfloat16 - if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output): - ubuf_dtype = torch.uint8 - sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda") - ub_obj = ub_obj = ( - tex.UbufP2PCommOverlap( - sample_buffer, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + buffer_dtype = torch.bfloat16 + if ( + opts.fp8 + and not opts.bulk_overlap + and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output) + ): + buffer_dtype = torch.uint8 + ub_obj = ( + tex.CommOverlapP2P( + (outer_size, hidden_size), + buffer_dtype, + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 1, # Number of communication SMs - 1, # CGA cluster size - opts.comm_type == 0 or opts.atomic, # Set SM margin - opts.aggregate, # Aggregate 2X GEMM chunks - 3, # Max concurrent GEMM streams - opts.comm_type == 0, # overlap with reduce scatter - opts.atomic, # use a single GEMM with atomic-counters - not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), - ub_callbacks, + opts.comm_type, + set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, + atomic_gemm=opts.atomic, + aggregate=opts.aggregate, + use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), ) if opts.p2p - else tex.UbufCommOverlap( - sample_buffer, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, 
# Number of nodes + else tex.CommOverlap( + (outer_size, hidden_size), + buffer_dtype, + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 16, # Number of communication SMs - 2, # CGA cluster size - 4, # Number of communication splits - True, # Set SM margin - 3, # Max concurrent GEMM streams - opts.atomic, # Use a single GEMM with atomic-counters - ub_callbacks, + atomic_gemm=opts.atomic, ) ) # Numerical check on AG + atomic GEMM requires testing an AG+RS pair ub_obj2 = None - if opts.atomic and opts.comm_type == 1 and opts.check_numerics: - sample_buffer2 = torch.empty( - (outer_size, hidden_size), - dtype=torch.uint8 if opts.fp8_output else torch.bfloat16, - device="cuda", - ) + if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics: ub_obj2 = ( - tex.UbufP2PCommOverlap( - sample_buffer2, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + tex.CommOverlapP2P( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 1, # Number of communication SMs - 1, # CGA cluster size - True, # Set SM margin - False, # Aggregate 2X GEMM chunks - 3, # Max concurrent GEMM streams - True, # overlap with reduce scatter - True, # use a single GEMM with atomic-counters - True, # use copy engine for P2P communications - ub_callbacks, + tex.CommOverlapType.RS, + set_sm_margin=True, + atomic_gemm=True, ) if opts.atomic_rs_p2p - else tex.UbufCommOverlap( - sample_buffer2, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + else tex.CommOverlap( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 16, # Number of communication SMs - 2, # CGA cluster size - 4, # Number of communication splits - True, # Set SM margin - 3, # Max concurrent GEMM streams - True, # uUe a single GEMM with atomic-counters - ub_callbacks, + atomic_gemm=True, ) ) @@ -426,12 +385,12 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None local_kernel_t_shape = (ffn_hidden_size, hidden_size) local_inp_shape = (outer_size, hidden_size) # Bulk overlap comm tensor is distributed for AG overlap only - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: bulk_inp_shape = (outer_size // tp_size, hidden_size) else: bulk_inp_shape = (outer_size, hidden_size) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P) local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size) local_inp_shape = (outer_size // tp_size, hidden_size) @@ -472,7 +431,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None std=opts.std, ) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K) ker_g = torch.transpose( te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1 @@ -494,7 +453,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ).to(dtype=torch.float32) if opts.bulk_overlap: - if opts.comm_type == 1: + if 
opts.comm_type == tex.CommOverlapType.AG: ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0] else: # First all-gather all the bulk inputs into a list @@ -505,7 +464,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else: ref_g = torch.matmul(inp_g, ker_g) if ub_obj2 is not None: - inp2_g = torch.nn.functional.gelu(ref_g) + inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = torch.matmul(inp2_g, ker2_g) if opts.fp8: @@ -529,7 +488,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_amax = torch.max(torch.abs(bulk_inp)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) elif ub_obj2 is not None: @@ -551,7 +510,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None kernel_t_fp8 = tex.cast_to_fp8( kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype ) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_inp_fp8 = tex.cast_to_fp8( bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype ) @@ -574,7 +533,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None rtol=0.125, atol=0.0675, ) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: torch.allclose( bulk_inp.to(dtype=torch.float32), bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], @@ -590,7 +549,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ) # Set Fp8 scales for userbuffers - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) if ub_obj2 is not None: ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) @@ -602,7 +561,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None # Set up comm/compute buffers ubuf_out2 = None rs_out2 = None - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: ub_obj.copy_input_to_ubuf(bulk_inp, 1) gemm_inp = inp @@ -686,9 +645,9 @@ def _fp8_gemm2(gemm1_out): gelu=False, use_split_accumulator=te.module.base._2X_ACC_FPROP, ub_algo=( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P if opts.atomic_rs_p2p - else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else tex.CommOverlapAlgo.ATOMIC_GEMM_RS ), ub=ub_obj2, extra_output_tensor=rs_out2, @@ -762,10 +721,14 @@ def _gemm(): avg_gpu_time = sum(gpu_times) / opts.timing_iters gemm_name = "".join( [ - "p2p all-gather + " if opts.comm_type == 1 else "", + "p2p all-gather + " if opts.comm_type == tex.CommOverlapType.AG else "", "atomic " if opts.atomic else "", "GEMM", - (f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == 0 else ""), + ( + f" + {'p2p ' if opts.p2p else ''}reduce-scatter" + if opts.comm_type == tex.CommOverlapType.RS + else "" + ), ] ) timing_info = ( @@ -781,7 +744,7 @@ def _gemm(): dist.barrier(tp_group) if opts.bulk_overlap: output_info = "" - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # Bulk overlap 
AG output is already gathered test_out = ub_obj.get_ubuf_output(1) else: @@ -794,7 +757,7 @@ def _gemm(): output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}" dist_print( output_info, - src=0 if opts.comm_type == 0 else None, + src=0 if opts.comm_type == tex.CommOverlapType.RS else None, section=True, ) @@ -805,7 +768,7 @@ def _gemm(): ) dist_print(nonzero_info, src=0, section=True, group=tp_group) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: if ub_obj2 is not None: # AG+RS Output: (M/P, N) -> gather -> (M, N) output = rs_out2.to(dtype=torch.float32) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e5653bda01..e32a7ccb12 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -9,7 +9,6 @@ import socket import argparse import warnings -from functools import partial import torch import torch.distributed as dist diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 63310195ae..ce46a72189 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -42,6 +42,9 @@ # Force GPU kernels to launch in the order they're executed by the host CPU os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +# Clear torch.dynamo caches +torch._dynamo.reset() + def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): test_path = TEST_ROOT / "run_gemm_with_overlap.py" diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 15fb994050..2d863b3bba 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -11,16 +11,24 @@ import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.common.recipe import DelayedScaling dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} def run_dpa_with_cp( - dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p" + dtype="bf16", + model=None, + qkv_format="bshd", + kernel_backend="FlashAttention", + cp_comm_type="p2p", + fp8_mha=False, ): """Test DotProductAttention module with context parallelism""" + # args are passed as strings + fp8_mha = fp8_mha == "True" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": @@ -72,7 +80,7 @@ def run_dpa_with_cp( cp_comm_sub_groups.append(sub_group) if dtype == "fp8": - fp8_recipe = DelayedScaling(fp8_dpa=True) + fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) # instantiate core attn module core_attn = DotProductAttention( @@ -201,7 +209,11 @@ def run_dpa_with_cp( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), ) - out.backward(dout) + if fp8_mha: + dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2) + out.backward(dout_fp8) + else: + out.backward(dout) # run core_attn wit CP q_, k_, v_, dout_, *rest = [ @@ -269,7 +281,11 @@ def run_dpa_with_cp( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), ) - out_.backward(dout_) + if fp8_mha: + dout_fp8_ 
= Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2) + out_.backward(dout_fp8_) + else: + out_.backward(dout_) for x in [out_, q_.grad, k_.grad, v_.grad]: assert torch.all(~torch.isnan(x)) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 4b4eecbf39..4e995dabb1 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -619,14 +619,14 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), + "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), + "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), @@ -644,6 +644,9 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): @pytest.mark.parametrize("qkv_layout", qkv_layouts_thd) def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" + config = model_configs[model] + if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: + pytest.skip("qkv_layout not applicable for MQA/GQA") pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs @@ -1353,8 +1356,6 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, config = model_configs_fp8_vs_f16[model] if _flash_attn_3_is_installed and not is_training: - if RoPE: - pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.") os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index ea30a4831f..1007d6aa34 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -113,7 +113,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) -def test_cp_with_fused_attention(dtype, model, qkv_format, 
cp_comm_type): +@pytest.mark.parametrize("fp8_mha", [False, True]) +def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): @@ -153,6 +154,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" ) + if dtype != "fp8" and fp8_mha: + pytest.skip("Only fp8 works with fp8_mha=True!") subprocess.run( get_bash_arguments( @@ -162,6 +165,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FusedAttention", cp_comm_type=cp_comm_type, + fp8_mha=fp8_mha, ), check=True, ) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index f804754949..dccf81829e 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -4,6 +4,7 @@ from itertools import product import copy +from contextlib import nullcontext import pytest import torch @@ -175,6 +176,216 @@ def test_frozen_model(self): torch.testing.assert_close(ref_param, tst_param) + def gen_precision_aware_test( + self, + use_fp8_params, + param_dtype, + use_master_weights, + master_weight_dtype, + grad_dtype, + exp_avg_dtype, + exp_avg_sq_dtype, + model_rtol=None, + model_atol=None, + master_rtol=None, + master_atol=None, + skip_assert=False, + ): + build_model_context = nullcontext + build_model_context_args = {} + if use_fp8_params: + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + with build_model_context(**build_model_context_args): + model = MultiheadAttention( + hidden_size=1024, + num_attention_heads=16, + layer_number=1, + params_dtype=param_dtype, + fuse_qkv_params=True, + ).cuda() + + ref_params = [] + model_params = [] + + for p in model.parameters(): + if p.requires_grad: + ref_params.append(p.detach().clone().float()) + model_params.append(p) + + options = { + "lr": 1, + "betas": (0.1, 0.25), + "eps": 1e-08, + "weight_decay": 0, + "amsgrad": False, + } + ref_optim = torch.optim.Adam(ref_params, **options) + tst_optim = te.optimizers.FusedAdam( + model_params, + master_weights=use_master_weights, + master_weight_dtype=master_weight_dtype, + exp_avg_dtype=exp_avg_dtype, + exp_avg_sq_dtype=exp_avg_sq_dtype, + use_decoupled_grad=True, + **options, + ) + + def test_one_iteration(ref_optimizer, tst_optimizer): + for p_ref, p in zip(ref_params, model_params): + p_ref.grad = torch.rand_like(p_ref) + p.decoupled_grad = p_ref.grad.clone().to(grad_dtype) + ref_optimizer.step() + tst_optimizer.step() + if use_master_weights: + master_weights_to_fp32 = [ + tst_optim.get_unscaled_state(p, "master_param") for p in model_params + ] + if not skip_assert: + torch.testing.assert_close( + ref_params, + master_weights_to_fp32, + rtol=master_rtol, + atol=master_atol, + equal_nan=True, + ) + ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params] + if not skip_assert: + torch.testing.assert_close( + ref_params_to_model_dtype, + model_params, + rtol=model_rtol, + atol=model_atol, + equal_nan=True, + ) + + for i in range(self.iters): + test_one_iteration(ref_optim, tst_optim) + + state_dict = tst_optim.state_dict() + tst_optim = te.optimizers.FusedAdam( + model_params, + 
master_weights=use_master_weights, + master_weight_dtype=master_weight_dtype, + exp_avg_dtype=exp_avg_dtype, + exp_avg_sq_dtype=exp_avg_sq_dtype, + use_decoupled_grad=True, + **options, + ) + tst_optim.load_state_dict(state_dict) + + for i in range(self.iters): + test_one_iteration(ref_optim, tst_optim) + + def test_fp32_no_master(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.float32, + use_master_weights=False, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + ) + + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_fp32_master(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + ) + + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_fp16_master(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.half, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + master_rtol=2e-3, + master_atol=2e-3, + ) + + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_bf16_grad(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.bfloat16, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + master_rtol=2e-3, + master_atol=2e-3, + ) + + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_fp16_exp_avg(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.half, + exp_avg_sq_dtype=torch.float32, + master_rtol=2e-3, + master_atol=2e-3, + ) + + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + def test_fp8_exp_avg(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.uint8, + exp_avg_sq_dtype=torch.float32, + master_rtol=1e-2, + master_atol=1e-2, + ) + + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_fp16_exp_avg_sq(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.half, + master_rtol=2e-3, + master_atol=2e-3, + ) + + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + def test_fp8_exp_avg_sq(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.uint8, + skip_assert=True, + ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") def 
test_bf16_model_weight_cast(self): dtype = torch.bfloat16 @@ -186,12 +397,10 @@ def test_bf16_model_weight_cast(self): fuse_qkv_params=True, ).cuda() ref_params = [] - master_params = [] model_params = [] for p in model.parameters(): if p.requires_grad: ref_params.append(p.detach().clone().float()) - master_params.append(p.detach().clone().float()) model_params.append(p) options = { "lr": 5e-4, @@ -201,12 +410,17 @@ def test_bf16_model_weight_cast(self): "amsgrad": False, } ref_optim = torch.optim.Adam(ref_params, **options) - tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) + tst_optim = te.optimizers.FusedAdam( + model_params, master_weights=True, use_decoupled_grad=True, **options + ) for i in range(self.iters): - self.gen_grad(ref_params, master_params) + for p_ref, p in zip(ref_params, model_params): + p_ref.grad = torch.rand_like(p_ref) + p.decoupled_grad = p_ref.grad.clone() ref_optim.step() tst_optim.step() + master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params] torch.testing.assert_close(ref_params, master_params) model_params_to_fp32 = [p.float() for p in model_params] torch.testing.assert_close( @@ -225,12 +439,10 @@ def test_fp8_model_weight_cast(self): fuse_qkv_params=True, ).cuda() ref_params = [] - master_params = [] model_params = [] for p in model.parameters(): if p.requires_grad: ref_params.append(p.detach().clone().float()) - master_params.append(p.detach().clone().float()) model_params.append(p) options = { "lr": 5e-4, @@ -240,12 +452,17 @@ def test_fp8_model_weight_cast(self): "amsgrad": False, } ref_optim = torch.optim.Adam(ref_params, **options) - tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) + tst_optim = te.optimizers.FusedAdam( + model_params, master_weights=True, use_decoupled_grad=True, **options + ) for i in range(self.iters): - self.gen_grad(ref_params, master_params) + for p_ref, p in zip(ref_params, model_params): + p_ref.grad = torch.rand_like(p_ref) + p.decoupled_grad = p_ref.grad.clone() ref_optim.step() tst_optim.step() + master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params] torch.testing.assert_close(ref_params, master_params) model_params_to_fp32 = [p.float() for p in model_params] torch.testing.assert_close( diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c0f45ada4e..c237dbaeb6 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -64,6 +64,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq model_configs = { + "small": ModelConfig(128, 1e-5, 8, 36, 4, 128), "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), } @@ -110,23 +111,30 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: def assert_allclose( - l1: List[torch.Tensor], - l2: List[torch.Tensor], - atol: float, + l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None ) -> bool: """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." for i, (t1, t2) in enumerate(zip(l1, l2)): - result = torch.allclose(t1, t2, atol=atol) + tols = dict(atol=atol) + if rtol is not None: + tols["rtol"] = rtol + result = torch.allclose(t1, t2, **tols) if not result: - diff = torch.abs(t1 - t2).flatten() - m = torch.argmax(diff) - msg = ( - f"Outputs not close enough in tensor at idx={i}. 
" - f"Location of the maximum difference: {m.item()} " - f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} " - f"(diff {diff[m].item()})." - ) + diff = torch.abs(t1 - t2) + tol = atol + (rtol * torch.abs(t2)) + exceed_mask = diff > tol + if exceed_mask.any(): + indices = torch.nonzero(exceed_mask, as_tuple=True) + max_diff = diff[exceed_mask].max() + max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] + max_location = [idx[max_idx].item() for idx in indices] + msg = ( + f"Outputs not close enough in tensor at idx={i}. " + f"Maximum difference at location {max_location} " + f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " + f"(diff {max_diff.item()})." + ) raise AssertionError(msg) @@ -526,7 +534,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): @@ -631,7 +639,7 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) @@ -764,7 +772,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) @@ -809,7 +817,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] @@ -868,11 +876,25 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config) torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config) + atol = { + torch.float32: 5e-3, + torch.half: 5e-2, + torch.bfloat16: 1e-1, + } + # Check output. 
- if dtype == torch.float32: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) - else: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + + # Check gradients, only for small model + if model == "small": + atol[torch.float32] = 5e-2 + rtol = { + torch.float32: 1e-2, + torch.half: 1e-2, + torch.bfloat16: 1e-2, + } + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @@ -906,7 +928,7 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] @@ -947,6 +969,21 @@ def test_mha_accuracy(dtype, bs, model, mask_type): else: assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + # Check gradients, only for small model + if model == "small": + atol = { + torch.float32: 5e-2, + torch.half: 5e-2, + torch.bfloat16: 5e-2, + } + rtol = { + torch.float32: 1e-2, + torch.half: 1e-2, + torch.bfloat16: 1e-2, + } + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + def _test_granular_accuracy(block, bs, dtype, config): reset_rng_states() @@ -1002,7 +1039,7 @@ def _test_dpa_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_dpa_accuracy(dtype, bs, model): config = model_configs[model] @@ -1034,10 +1071,13 @@ def test_dpa_accuracy(dtype, bs, model): else: assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) def test_linear_accuracy(dtype, bs, model): config = model_configs[model] @@ -1066,15 +1106,20 @@ def test_linear_accuracy(dtype, bs, model): torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config) # Check output. 
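# (Editor's aside ahead of the tolerance changes below, not part of the diff: several
# hunks in this file introduce the same dtype-keyed atol/rtol tables. A shared helper in
# this spirit would express the pattern once; the name dtype_tolerances is invented and
# the numbers are merely examples copied from the surrounding hunks.)
import torch

def dtype_tolerances(dtype, atol_overrides=None, rtol_overrides=None):
    atol = {torch.float32: 5e-3, torch.half: 5e-2, torch.bfloat16: 5e-2}
    rtol = {torch.float32: 1.3e-6, torch.half: 1e-2, torch.bfloat16: 2e-2}
    atol.update(atol_overrides or {})
    rtol.update(rtol_overrides or {})
    return atol[dtype], rtol[dtype]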
- if dtype == torch.float32: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) - else: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + if model == "small": + tolerance = 5e-3 if dtype == torch.float32 else 5e-2 + rtol = { + torch.float32: 1.3e-6, + torch.half: 1e-2, + torch.bfloat16: 2e-2, + } + for te_output, torch_output in zip(te_outputs, torch_outputs): + assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7]) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): @@ -1102,18 +1147,29 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config) - # Check output. atol = { torch.float32: 1e-7, torch.half: 2e-3, torch.bfloat16: 2e-2, } + + # Check output. assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + atol[torch.float32] = 2e-3 + rtol = { + torch.float32: 1.3e-6, + torch.half: 1e-3, + torch.bfloat16: 1.6e-2, + } + # Check gradients + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7]) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): @@ -1142,18 +1198,29 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config) - # Check output. atol = { torch.float32: 1e-7, torch.half: 2e-3, torch.bfloat16: 2e-2, } + + # Check output. assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + rtol = { + torch.float32: 1.3e-6, + torch.half: 1e-3, + torch.bfloat16: 1.6e-2, + } + atol[torch.float32] = 1e-4 + # Check gradients + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma): @@ -1195,18 +1262,34 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config) - # Check output. atol = { torch.float32: 2.5e-4, torch.half: 2e-3, torch.bfloat16: 2e-2, } + + # Check output. 
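+    # te_outputs[0] is the forward output; gradients are compared further below for
+    # the small model with separate, looser atol/rtol tables.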
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + if model == "small": + atol = { + torch.float32: 1e-3, + torch.half: 5e-2, + torch.bfloat16: 5e-2, + } + rtol = { + torch.float32: 1e-3, + torch.half: 4e-2, + torch.bfloat16: 4e-2, + } + # Check gradients + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("normalization", all_normalizations) def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): @@ -1246,11 +1329,26 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config) + atol = { + torch.float32: 2e-2, + torch.half: 5e-2, + torch.bfloat16: 5e-2, + } + # Check output. - if dtype == torch.float32: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) - else: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + + # Check gradients, only for small model + rtol = { + torch.float32: 1e-3, + torch.half: 1e-2, + torch.bfloat16: 4e-2, + } + atol[torch.half] = 2e-1 + atol[torch.bfloat16] = 2e-1 + if model == "small": + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): @@ -1301,7 +1399,7 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_grouped_linear_accuracy( @@ -1361,7 +1459,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): dtype=torch.float32, num_gemms=6, bs=2, - model=list(model_configs.keys())[0], + model="126m", fp8=True, fp8_model_params=True, parallel_mode=parallel_mode, @@ -1374,7 +1472,7 @@ def test_grouped_linear_accuracy_single_gemm(): dtype=torch.float32, num_gemms=1, bs=2, - model=list(model_configs.keys())[0], + model="126m", fp8=True, fp8_model_params=True, ) @@ -1475,7 +1573,7 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( @@ -1594,7 +1692,7 @@ def train_step(): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_gpt_cuda_graph(dtype, bs, model): config = model_configs[model] @@ -1686,7 +1784,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): 
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_gpt_fp8_parameters(dtype, bs, model): if not fp8_available: pytest.skip(reason_for_no_fp8) @@ -1710,7 +1808,7 @@ def test_gpt_fp8_parameters(dtype, bs, model): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_transformer_layer_hidden_states_format(dtype, bs, model): config = model_configs[model] diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cabb2e2aea..ca23008edd 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -80,7 +80,11 @@ list(APPEND transformer_engine_SOURCES fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_rope/fused_rope.cu - recipe/delayed_scaling.cu) + recipe/delayed_scaling.cu + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.cu + comm_gemm_overlap/comm_gemm_overlap.cpp) add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ -93,6 +97,15 @@ target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) +if (NVTE_UB_WITH_MPI) + find_package(MPI REQUIRED) + target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) + target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp new file mode 100644 index 0000000000..59ec56f161 --- /dev/null +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -0,0 +1,980 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include + +#include +#include + +#include "common/common.h" +#include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "common/util/system.h" +#include "userbuffers/userbuffers.h" + +#define HALF_BYTES 2 +#define UB_MAX_SM 32 + +using namespace std::placeholders; + +namespace transformer_engine { + +/*************************************************************************************************** + * Comm+GEMM Overlap Common Core + **************************************************************************************************/ + +bool ubuf_built_with_mpi() { +#ifdef NVTE_UB_WITH_MPI + return true; +#else + return false; +#endif +} + +CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm) { + // Initialize userbuf communicator + if (!_comm_created) { + if (myrank == 0) { + printf("!!! [UB] Create Userbuffers Communicator\n"); + } +#ifdef NVTE_UB_WITH_MPI + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); +#else + create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + allgather_handle, barrier_handle, 1, 1, tp_size, 1); +#endif + _comm_created = true; + } + _use_ce = static_cast(use_ce); + _num_comm_sm = num_comm_sm; + _cga_size = comm_cga_size; + + for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1)); + _stream_compute.push_back(std::move(stream)); + } + + _num_splits = num_splits; + _rank = _ub_comm->myrank; + _tp_size = tp_size; + _tp_id = _rank % _tp_size; + + // Set the number of SMs for GEMM with margin + int sm_count = transformer_engine::cuda::sm_count(); + _math_sms = (set_sm_margin) ? 
sm_count - num_comm_sm : sm_count; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + void *counter_ptr; + size_t counter_bytes = _num_splits * 2 * sizeof(int32_t); + NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes)); + NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes)); + NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2)); + _counter = TensorWrapper(counter_ptr, std::vector{static_cast(_num_splits * 2)}, + DType::kInt32); + } + // CUDA event creation + cudaEventCreateWithFlags(&_start_compute, 0); + cudaEventCreateWithFlags(&_stop_compute, 0); + cudaEventCreateWithFlags(&_start_comm, 0); + cudaEventCreateWithFlags(&_stop_comm, 0); +} + +CommOverlapCore::~CommOverlapCore() { + cudaEventDestroy(_stop_comm); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); + + if (_atomic_gemm) cudaFree(_counter.dptr()); + + for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + + if (_comm_created) { +#ifdef NVTE_UB_WITH_MPI + destroy_communicator_mpi(_ub_comm); +#else + destroy_communicator(_ub_comm); +#endif + _comm_created = false; + } +} + +/*************************************************************************************************** + * Comm+GEMM Overlap Base (Pipelined / Collective) + **************************************************************************************************/ + +CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool atomic_gemm) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, + num_comm_sm, set_sm_margin, false, atomic_gemm) { + _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); + NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, + "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", + "or 2 (multi-atomic)."); + + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_ub_comm->myrank == 0) printf("!!! 
[UB] Register UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); + + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); +} + +CommOverlapBase::~CommOverlapBase() { + cudaEventDestroy(_start_d2dcopy); + cudaStreamDestroy(_stream_comm); +} + +/* +** Bulk GEMM + COMM +** This function assumes the communication input is pre-copied to _ubuf +*/ +void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication: AG and RS + int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size + if (comm_type == CommOverlapType::AG) { + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + } else { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + comm_elements *= 2; + assert(rs_output.numel() == _ubuf.numel() / _tp_size); + assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); + assert(rs_output.element_size() == 2); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, + comm_elements, _ub_comm, _stream_comm); + } else { + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + } + } + + assert(pre_gelu_out.numel() == 0); + nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, + grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, + stream_main); + + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::bulk_overlap + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions + size_t m = A.size(0); + size_t k = A.size(1); + size_t n = B.size(0); + size_t m_chunk = m / _num_splits; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _num_splits, false, stream_main); + + // Catch up the default torch 
stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr); + auto workspace_chunk = + TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), + _stream_compute[0]); + + for (int i = 0; i < _num_splits; i++) { + if (_rs_kernel_type == 1) { + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_atomic_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + &counter_ptr[i], _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _num_splits, &counter_ptr[i], _ub_comm, + _stream_comm); + } + } else if (_rs_kernel_type == 2) { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_multiatomic_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + counter_ptr, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, + _num_splits, counter_ptr, _ub_comm, + _stream_comm); + } + break; + } else { + consumer(counter_ptr, i, _stream_comm); + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, _ubuf_scale_inv, + _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm); + } + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // split_overlap_rs + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main) { + // Get GEMM dimensions + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + size_t m = A.size(0); + size_t k = A.size(1); + size_t n = B.size(0); + size_t m_chunk = m / _num_splits; + size_t input_a_chunk_size = m_chunk * k; + size_t output_chunk_size = n * m_chunk; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char 
*input_a_chunk_ptr = reinterpret_cast(A.dptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + if (gemm_overlap) { + auto input_a_chunk = + TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); + auto output_chunk = + TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto workspace_chunk = TensorWrapper( + workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[0]); + + for (int i = 1; i < _num_splits; i++) { + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * D.element_size(); + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + + input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, + A.dtype(), nullptr, nullptr, A.scale_inv()); + output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, + D.dtype(), D.amax(), D.scale(), nullptr); + workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA( + cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication chunk + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Last communication chunk with max SM + _ub_comm->sms = UB_MAX_SM; + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (_num_splits - 1) * 
output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm); + } + } else { + for (int i = 0; i < _num_splits; i++) { + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + + auto input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, + A.dtype(), nullptr, nullptr, A.scale_inv()); + auto output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), + {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication chunk. Uses MAX_SM at the last chunk + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + } + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::split_overlap_rs + +/*************************************************************************************************** + * Comm+GEMM Overlap P2P Base (Ring-Exchange) + **************************************************************************************************/ + +CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + _is_p2p = true; + _is_reduce_scatter = comm_type == CommOverlapType::RS; + _aggregate = aggregate; + + // Create workspace tensor with userbuffer + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + int buffer_chunk_bytes = buffer_bytes / tp_size; + _num_ubuf_chunks = tp_size; + if (_is_reduce_scatter) { + // GEMM + RS overlap: Allocate `2 x tp_size - 
1` buffers to hold recieved GEMM chunk + // outputs for reduction at the end of the pipelining. + buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); + _num_ubuf_chunks = tp_size * 2 - 1; + } + + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + buffer_dtype); + + // Create tensor chunks for easy management + char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); + for (int i = 0; i < _num_ubuf_chunks; i++) { + _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), + {buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype)); + ubuf_byte_ptr += buffer_chunk_bytes; + } + + _rank_round_tp = (_rank / _tp_size) * _tp_size; + _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; + _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; + + _self_chunk_id = _tp_id; + if (_atomic_gemm && !_is_reduce_scatter) { + _use_multiatomic_ag = getenv("NVTE_AG_P2P_MULTI_ATOMIC"); + if (_use_multiatomic_ag) { + _use_ce = 0; + _ub_comm->push = 1; + if (_rank == 0) { + printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); + } + } + _self_chunk_id = 0; + NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); + } + + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); +} + +CommOverlapP2PBase::~CommOverlapP2PBase() { + cudaEventDestroy(_stop_recv); + cudaEventDestroy(_stop_send); + cudaStreamDestroy(_stream_recv); + cudaStreamDestroy(_stream_send); +} + +/* +** Split AllGather + AtomicGEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? 
A.size(0) : A.size(1); + const size_t n = _ubuf.size(0); + const size_t n_chunk = n / _tp_size; + assert(pre_gelu_out.numel() == 0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Create an GEMM output buffer with N+1 chunks in a contiguous memory + void *D_buffer_ptr; + int D_chunk_bytes = n_chunk * m * D.element_size(); + NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, true, stream_main); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + auto workspace_chunk = + TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + + for (int i = 0; i < _tp_size - 1; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = i; + int recv_chunk_id = i + 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + if (_use_multiatomic_ag) { + if (i == 0) { + _ub_comm->use_ce = 0; + userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, + _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, + true, _stream_recv); + } + } else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, + _stream_recv); + producer(counter_ptr, recv_chunk_id, _stream_recv); + } + if (i == 0) { + nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false, + _counter.data(), stream_main); + } + } + + // Store the input activation for backprop + if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); + assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); + NVTE_CHECK_CUDA( + cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), + _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + } + + // Copy the first GEMM output chunk to the end chunk position of D_buffer + char *src_ptr = reinterpret_cast(D_buffer.dptr()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes, + cudaMemcpyDeviceToDevice, stream_main)); + + // Return the last N rows of D_buffer + NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(), + cudaMemcpyDeviceToDevice, stream_main)); + + // Clean 
up buffer allocation + NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main)); + + _ub_comm->sms = ori_sms; +} // CommOverlapP2PBase::atomic_gemm_overlap_ag + +/* +** Split AllGather + GEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const bool do_gelu = pre_gelu_out.numel() > 0; + const int output_chunk_bytes = (n_chunk * m) * D.element_size(); + const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; + + // Get output and workspace data pointers + char *output_ptr = reinterpret_cast(D.dptr()); + char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + if (_aggregate) { + const int num_steps = _tp_size / 2; + char *input_b_ptr = reinterpret_cast(_ubuf.dptr()); + + // Initial 1X input chunk exchange between neighboring peers + int send_chunk_id = _tp_id; + int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, + _stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, + _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); + + int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; + const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; + const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; + + // Ring exchange of 2X inputs chunks + for (int i = 0; i < num_steps; i++) { + send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; + recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; + send_offset = comm_bytes * send_chunk_id; + recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + char *input_b_chunk_ptr = input_b_ptr + send_offset; + auto input_b_chunk = + TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); + + char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), + {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); + + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = + (do_gelu) ? std::vector{n_chunk * 2, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, + pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < num_steps - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, + next_rank, _stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, + prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send)); + } + } + } else { + for (int i = 0; i < _tp_size; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; + int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); + + char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), {n_chunk, m}, + D.dtype(), D.amax(), D.scale(), nullptr); + + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = (do_gelu) ? 
std::vector{n_chunk, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, + pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < _tp_size - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + _next_rank, _stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send)); + } + } + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // CommOverlapP2PBase::split_overlap_ag + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get communication and GEMM input chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Reset counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, false, stream_main); + + // Catch up the main stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + // Atomic GEMM + // Process GEMM chunks in the order that AG+GEMM places the output chunks. 
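+  // The atomic GEMM below flags each completed output chunk in _counter; the ring loop
+  // that follows consumes chunk i-1 before exchanging it with its peer rank, so the
+  // reduce-scatter traffic overlaps with the remaining GEMM work.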
+ auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + auto workspace_chunk = + TensorWrapper(workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(), + stream_main); + + // P2P communication chunk + for (int i = 1; i < _tp_size; i++) { + int send_chunk_id = i - 1; + int recv_chunk_id = send_chunk_id + _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + + consumer(counter_ptr, send_chunk_id, _stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + _ub_comm->sms = ori_sms; +} + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + size_t k = A.size(1); + size_t n = B.size(0); + + // Get communication and GEMM input chunk sizes + size_t n_chunk = n / _tp_size; + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int input_b_chunk_bytes = n_chunk * k * B.element_size(); + + // Get input and workspace data pointers + char *input_b_ptr = reinterpret_cast(B.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Catch up the main stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + // GEMM and send/recv chunks + for (int i = 0; i < _tp_size; i++) { + // GEMM chunk + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; + char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); + + auto input_b_chunk 
= TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, + B.dtype(), nullptr, nullptr, B.scale_inv()); + + auto output_chunk = + TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); + + char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); + + if (i > 0) { + // P2P communication chunk + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + NVTE_CHECK_CUDA( + cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_send); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + } + } + + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + + _ub_comm->sms = ori_sms; +} + +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc rename to transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h rename to transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp similarity index 92% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp rename to 
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index e2628f6a31..6f3eef3d28 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -20,7 +20,9 @@ #include #include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" #include "common/util/logging.h" +#include "common/util/system.h" #include "ipcsocket.h" #include "userbuffers.h" @@ -44,31 +46,19 @@ static MPI_Comm EXT_COMM_INTER; } while (false) void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, - ExtComm group) { - // UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE, - // globaldata, globalbytes, MPI_BYTE, - // static_cast(group))); - MPI_Comm comm = static_cast(group); + ExtComm comm) { int numranks; UB_MPI_CHECK(MPI_Comm_size(comm, &numranks)); assert(globalbytes == numranks * localbytes); - - int myrank; - UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank)); - char *globaltarget = reinterpret_cast(globaldata) + (myrank * localbytes); - memcpy(globaltarget, localdata, localbytes); - - for (int n = 0; n < numranks; n++) { - globaltarget = reinterpret_cast(globaldata) + (n * localbytes); - UB_MPI_CHECK(MPI_Bcast(globaltarget, localbytes, MPI_BYTE, n, comm)); - } + UB_MPI_CHECK( + MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm)); } -void ub_mpi_barrier(ExtComm group) { UB_MPI_CHECK(MPI_Barrier(static_cast(group))); } +void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } #else -static char EXT_COMM_WORLD[] = "world"; -static char EXT_COMM_INTRA[] = "intra"; -static char EXT_COMM_INTER[] = "inter"; +#define EXT_COMM_WORLD "world" +#define EXT_COMM_INTRA "intra" +#define EXT_COMM_INTER "inter" #endif #define MULTICAST_GB_TOTAL 512 @@ -106,11 +96,10 @@ int pipe_rank(communicator *comm, int step) { return newnode * numlocal + newlocal; } -int create_communicator_grouped2( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes, int tensorgpus, - int tensornodes) { +int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes, int tensorgpus, int tensornodes) { *comm = new communicator(); (*comm)->comm_world = EXT_COMM_WORLD; @@ -214,8 +203,11 @@ int create_communicator_grouped2( (*comm)->asyncblocks = 16; #define NBUF 2 - if ((*comm)->sm_arch >= 9 && (*comm)->ar2_nvsize > 1 && - !getenv("UB_SKIPMC")) { // multicast init only for TP ops (____2 operations) + +#if CUDART_VERSION >= 12010 + if (!transformer_engine::getenv("UB_SKIPMC") && + transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { + // multicast init only for TP ops (____2 operations) size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30); (*comm)->mc_offset = 0; (*comm)->use_mc = 1; @@ -291,20 +283,20 @@ int create_communicator_grouped2( (*comm)->_barrier((*comm)->comm_world); if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize); } else { +#endif if (!(*comm)->myrank) printf("MC NOT initialized and used\n"); (*comm)->mc_maxsize = 0; (*comm)->mc_offset = 0; (*comm)->use_mc = 0; +#if CUDART_VERSION >= 12010 } +#endif #define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + 
NVTE_REG0_COMMBUFFER * NBUF) // peer pointers + op flags + comm buffer - NVTE_CHECK_CUDA( - cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet - NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false); + register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true); NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); @@ -346,18 +338,17 @@ int create_communicator_grouped2( return 0; } -int create_communicator_grouped( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes) { +int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1); } int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, - std::function ext_allgather, - std::function ext_barrier) { + int mynode, int numnodes, ExtAllgatherOp ext_allgather, + ExtBarrierOp ext_barrier) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, ext_allgather, ext_barrier, 1, 1, 1, 1); } @@ -428,7 +419,7 @@ int create_communicator_mpi(communicator **comm) { void destroy_communicator(communicator *comm) { for (int hndl = 0; hndl < comm->free_region; hndl++) { - if (hndl > 0 && comm->use_mc && comm->mem_dealloc[hndl]) { + if (comm->use_mc && comm->mem_dealloc[hndl]) { for (int rank = 0; rank < comm->nvsize; rank++) { if (rank == comm->nvrank) { NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]); @@ -479,6 +470,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->memflags[hndl] = 0; comm->mem_dealloc[hndl] = alloc; +#if CUDART_VERSION >= 12010 if (comm->use_mc && alloc) { int nranks = comm->nvsize; // total GPUs in NVLINK domain int myrank = comm->nvrank; @@ -594,6 +586,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * } } else { +#endif if (alloc) { NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes)); NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes)); @@ -624,7 +617,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * NVTE_CHECK_CUDA(cudaDeviceSynchronize()); free(tmp); +#if CUDART_VERSION >= 12010 } +#endif comm->mem_size[hndl] = aligned_size; comm->mem_ptr[hndl] = *gpubuff; diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu similarity index 98% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu rename to transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 0cd2a0253b..26843d8107 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -392,14 +392,14 @@ __global__ 
void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) -#if __CUDA_ARCH__ >= 900 +#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 // All MC kernels here template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int lineoffset, const int numlines, void **commbuff, const int handleridx, - float4 *mc_ptr) { + float4 *mc_ptr, const uint64_t ub_timeout) { int *flagptr, physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; @@ -417,7 +417,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { + if (clock64() - s > ub_timeout) { UB_PRINT("Reduce-scatter: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, *flag); break; @@ -484,7 +484,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > 2ull * TIMEOUT) { + if (clock64() - s > 2ull * ub_timeout) { UB_PRINT("Allgather: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, *flag); break; @@ -741,7 +741,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int lineoffset, const int numlines, void **commbuff, const int handleridx, - float4 *mc_ptr) {} + float4 *mc_ptr, const uint64_t ub_timeout) {} template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset, @@ -2496,6 +2496,18 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i } } +// reset counters kernel +static __global__ void reset_counters_kernel(void *atomic_ptr, int num_chunks, bool allgather) { + if (blockIdx.x == 0 && threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < num_chunks; i++) { + ((unsigned int *)atomic_ptr)[i] = 1; + ((unsigned int *)atomic_ptr)[i + num_chunks] = 0; + } + if (allgather) ((unsigned int *)atomic_ptr)[0] = 0; + } +} + void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); @@ -2514,6 +2526,12 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr consumer_batch_kernel<<>>(atomic_ptr, first_chunk_i, num_chunks); } +void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) { + dim3 block(1); + dim3 grid(1); + reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); +} + template __global__ void __launch_bounds__(MAX_THREADS / 4) reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, @@ -2546,3 +2564,24 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); + +__global__ void __launch_bounds__(MAX_THREADS / 4) + reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) { + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + half *inputs_half = reinterpret_cast(inputs); + float accum_buf = static_cast(inputs_half[tid]); 
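
The multicast kernels above now take the watchdog threshold as a kernel argument (ub_timeout) instead of the compile-time TIMEOUT, so the caller can tune it at run time. A stripped-down device-side sketch of the spin-wait pattern; the real kernels use CHECK_IDS() and UB_PRINT(), replaced here by a plain comparison and printf.

```cpp
// Sketch of the flag-wait loop with a runtime timeout (illustrative).
#include <cstdint>
#include <cstdio>

__device__ void wait_for_flag(volatile int *flag, int expected, uint64_t ub_timeout) {
  long long start = clock64();
  while (*flag != expected) {
    if (clock64() - start > static_cast<long long>(ub_timeout)) {
      printf("UB spin-wait timed out: expecting %d got %d\n", expected, *flag);
      break;  // give up instead of hanging the SM indefinitely
    }
  }
}
```
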
+#pragma unroll + for (int i = 1; i < num_inputs; i++) { + accum_buf += static_cast(inputs_half[tid + input_size * i]); + } + half *output_half = reinterpret_cast(output); + output_half[tid] = (half)accum_buf; +} + +void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) { + size_t num_threads = MAX_THREADS / 4; + size_t num_blocks = (input_size + num_threads - 1) / num_threads; + dim3 block(num_threads); + dim3 grid(num_blocks); + reduce_bf16_cuda<<>>(inputs, output, num_inputs, input_size); +} diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h similarity index 90% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers.h rename to transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 371932f446..57e68afce0 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -19,11 +19,14 @@ #ifdef NVTE_UB_WITH_MPI #include -typedef MPI_Comm ExtComm; +#define ExtComm MPI_Comm #else -typedef char *ExtComm; +#define ExtComm const char * #endif +using ExtAllgatherOp = std::function; +using ExtBarrierOp = std::function; + #define NVTE_MAX_REGIONS 16 #define NVTE_MAX_SMS 32 #define NVTE_MAX_OPS 32 @@ -142,12 +145,12 @@ struct communicator { volatile int tail; // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks) - std::function _allgather; - std::function _barrier; + ExtAllgatherOp _allgather; + ExtBarrierOp _barrier; - ExtComm comm_world, - comm_inter, // reduction group communicator (subset of the nodes) along GPU rail - comm_intra; // full intranode (all ndev GPUS) + ExtComm comm_world; + ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail + ExtComm comm_intra; // full intranode (all ndev GPUS) #ifdef NVTE_UB_WITH_MPI MPI_Request mpihndl[NVTE_MAX_SHARP]; #endif @@ -161,23 +164,22 @@ typedef struct communicator communicator; void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream); +void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream); /* creates communicator, allocates all internal buffers if necessary */ -int create_communicator_grouped2( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes, int tensorgpus, - int tensornodes); +int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes, int tensorgpus, int tensornodes); -int create_communicator_grouped( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes); +int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes); int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, - int mynode, 
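
The new reduce_bf16() entry point sums num_inputs half-precision buffers element-wise into output; the inputs are expected back to back in one allocation, with input i starting at element offset i * input_size. A host-side usage sketch with illustrative buffer names; note that, as written in the kernel above, the data is reinterpreted as half and there is no tail guard, so input_size is chosen as a multiple of the launch block size.

```cpp
// Usage sketch for reduce_bf16 (illustrative).
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "userbuffers.h"  // declares reduce_bf16(); include path depends on the build

void sum_chunks(cudaStream_t stream) {
  const int num_inputs = 4;
  const int input_size = 1 << 20;  // elements per input; multiple of the block size

  half *inputs = nullptr, *output = nullptr;
  cudaMalloc(&inputs, sizeof(half) * num_inputs * input_size);
  cudaMalloc(&output, sizeof(half) * input_size);
  // ... fill `inputs` ...

  // output[j] = sum over i of inputs[i * input_size + j]
  reduce_bf16(inputs, output, num_inputs, input_size, stream);
}
```
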
int numnodes, - std::function ext_allgather, - std::function ext_barrier); + int mynode, int numnodes, ExtAllgatherOp ext_allgather, + ExtBarrierOp ext_barrier); int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, int tensorgpus, int tensornodes); @@ -314,4 +316,6 @@ template void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); +void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream); + #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index afa8a9c58d..4ea0ea5741 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -183,8 +183,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // qkv format ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && - (cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 90600))) && + ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || + (cudnn_runtime_version >= 90600)))) && // sliding window ((cudnn_runtime_version < 90200 && window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || @@ -272,6 +272,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); } size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); @@ -292,7 +297,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, + b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); @@ -349,6 +354,11 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); } size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); @@ -377,7 +387,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, 
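
The extra parentheses in the nvte_get_fused_attn_backend hunk above appear to be a behavioral fix rather than a reformat, since '&&' binds tighter than '||'. A reduced sketch of the two groupings, with shortened names:

```cpp
// A = (qkv_format is THD && sm_arch_ >= 90)
// B = (cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups)
// C = (cudnn_runtime_version >= 90600)
bool old_grouping(bool A, bool B, bool C) { return (A && B) || C; }  // what was written
bool new_grouping(bool A, bool B, bool C) { return A && (B || C); }  // what is intended
// Example: A = false, C = true  ->  old_grouping() == true, new_grouping() == false,
// i.e. the standalone C term could satisfy the qkv-format clause on any cuDNN 9.6+ build.
```
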
input_QKV, input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); @@ -442,6 +452,13 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } else { NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_KV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -463,9 +480,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -526,6 +543,13 @@ void nvte_fused_attn_bwd_kvpacked( } else { NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_KV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -556,9 +580,9 @@ void nvte_fused_attn_bwd_kvpacked( input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_KV, - input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, + input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -616,6 +640,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t h_kv = input_K->data.shape[ndim - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_K->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -637,9 
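
The repeated pattern above reads the total token count off shape[0] whenever the layout is THD: ragged inputs pack variable-length sequences back to back, so the leading dimension is the token total, which equals the last entry of the corresponding cu_seqlens array (b + 1 entries). A small hypothetical helper illustrating that relationship:

```cpp
// Illustrative helper, not library code: cu_seqlens for packed (THD) sequences.
#include <cstdint>
#include <vector>

std::vector<int32_t> build_cu_seqlens(const std::vector<int32_t> &seqlens) {
  std::vector<int32_t> cu_seqlens(seqlens.size() + 1, 0);
  for (size_t i = 0; i < seqlens.size(); ++i) {
    cu_seqlens[i + 1] = cu_seqlens[i] + seqlens[i];
  }
  return cu_seqlens;  // cu_seqlens.back() == total tokens == shape[0] of a THD tensor
}
```
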
+668,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, - input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -696,6 +727,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t h_kv = input_K->data.shape[ndim - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_K->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -726,10 +764,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, - input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, - output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, + output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 176ec50cd0..1a555a4999 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -49,14 +49,14 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, - void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, 
- cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, + void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -73,10 +73,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); - if (is_ragged) { + const auto cudnn_runtime_version = cudnnGetVersion(); + + // keep original batch size because cu_seqlens are created with [b+1] shape + int64_t actual_b = b; + if (is_ragged && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = max_t_q; + s_kv = max_t_kv; } - const auto cudnn_runtime_version = cudnnGetVersion(); const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
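
As the comment above says, for ragged input on cuDNN 9.6+ the graph dimensions are taken from bucketed maxima (max_b, max_t_q, max_t_kv) while actual_b is kept for the b + 1 sized cu_seqlens buffers. The get_max_batch_size / get_max_tokens helpers are not part of this diff; the rounding rule below is purely hypothetical and only illustrates why bucketing keeps the cached graph reusable.

```cpp
// Hypothetical illustration of the "quantization bucket" idea.
#include <cstdint>

int64_t next_bucket(int64_t n) {  // e.g. round up to the next power of two
  int64_t bucket = 1;
  while (bucket < n) bucket <<= 1;
  return bucket;
}
// Building the graph for next_bucket(t_q) / next_bucket(t_kv) instead of the exact
// token counts means the graph cached per descriptor can be reused by any later
// iteration whose sizes land in the same bucket, instead of being rebuilt each step.
```
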
DType::kInt64 : DType::kInt32; try { @@ -117,6 +125,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // offset_k std::shared_ptr, // offset_v std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed std::shared_ptr>; // dropout_offset @@ -140,30 +149,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr bias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v, offset_o; + std::shared_ptr offset_q, offset_k, offset_v, offset_o, + offset_stats; std::shared_ptr dropout_seed, dropout_offset; - offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -175,6 +164,21 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_QKV_Matrix::NVTE_V_Matrix); if (is_ragged) { + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -268,6 +272,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); O->set_output(true) .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) @@ -276,10 +285,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); } - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); + if (is_ragged && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, 1, h, 1}) + .set_ragged_offset(offset_stats); + } else { + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, 
s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}); + } std::tuple, // Q std::shared_ptr, // K @@ -291,8 +314,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + ? std::make_tuple(offset_stats) + : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -302,15 +328,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, padding_tuple, offset_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + padding_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, offset_q, offset_k, - offset_v, offset_o, dropout_seed, dropout_offset] = + offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -318,10 +345,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( // We do this by adding padding at the end of each separate allocation. auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size()); const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); - const size_t actual_seqlen_workspace_size = 2 * num_bytes_per_seqlen; + const size_t actual_seqlen_workspace_size = is_padding ? 
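
With cuDNN 9.6+ and ragged input, the softmax stats switch from the dense per-sequence layout to a packed, token-major layout addressed through the new offset_stats ragged offset. A sketch of the index arithmetic implied by the two stride sets above (illustrative, element offsets only):

```cpp
#include <cstdint>

// Dense layout: dim {b, h, s_q, 1}, stride {h*s_q, s_q, 1, 1}.
int64_t stats_index_dense(int64_t batch, int64_t head, int64_t pos, int64_t h, int64_t s_q) {
  return batch * h * s_q + head * s_q + pos;
}

// Packed THD layout: dim {b, h, s_q, 1}, stride {h*s_q, 1, h, 1}, with a per-sequence
// start supplied by the offset_stats ragged offset: h stats per token, token-major.
int64_t stats_index_packed(int64_t seq_start_offset, int64_t head, int64_t pos, int64_t h) {
  return seq_start_offset + pos * h + head;
}
```
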
2 * num_bytes_per_seqlen : 0; const size_t num_bytes_per_ragged_offset = alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); - const size_t seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + size_t seqlen_offsets_workspace_size = 0; + if (is_ragged) { + if (cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + } else { + seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + } + } if (workspace == nullptr) { *workspace_size = plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size; @@ -348,7 +382,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + num_bytes_per_seqlen; cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlensQ), + actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); variant_pack[seq_q] = devActualSeqlenQ; @@ -363,15 +397,22 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + void *devOffsetsS = nullptr; + if (cudnn_runtime_version >= 90600) { + devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), + layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, - devOffsetsV, devOffsetsO); + devOffsetsV, devOffsetsO, devOffsetsS); variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; variant_pack[offset_o] = devOffsetsO; + if (cudnn_runtime_version >= 90600) { + variant_pack[offset_stats] = devOffsetsS; + } } if (is_dropout) { @@ -386,12 +427,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, - void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, + void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, + void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, 
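
The forward workspace is carved into three consecutive regions: the cuDNN plan workspace, the converted actual-seqlen arrays (only when a padding mask is in use), and the ragged offset arrays (4 for Q/K/V/O, or 5 once the cuDNN 9.6+ stats offset is added). A small helper mirroring that size computation; align16 stands in for alignTo<16>.

```cpp
// Size bookkeeping mirrored from the code above (illustrative helper).
#include <cstddef>
#include <cstdint>

constexpr size_t align16(size_t n) { return (n + 15) / 16 * 16; }

size_t fwd_workspace_bytes(size_t plan_bytes, int64_t b, size_t ragged_offset_elem_bytes,
                           bool is_padding, bool is_ragged, bool cudnn_ge_9_6_0) {
  size_t total = align16(plan_bytes);
  if (is_padding) {
    total += 2 * align16(b * sizeof(int32_t));  // actual seqlens for Q and KV
  }
  if (is_ragged) {
    const size_t per_offset = align16((b + 1) * ragged_offset_elem_bytes);  // int32 or int64
    total += (cudnn_ge_9_6_0 ? 5 : 4) * per_offset;  // Q, K, V, O offsets (+ Stats offset)
  }
  return total;
}
```
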
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -414,6 +456,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); + // keep original batch size because cu_seqlens are created with [b+1] shape + int64_t actual_b = b; + if (is_ragged && cudnn_runtime_version >= 90600) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = max_t_q; + s_kv = max_t_kv; + } // We choose between 32-bit and 64-bit offsets depending on need. // This allows us to support older cuDNN runtimes gracefully. @@ -462,6 +514,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // offset_k std::shared_ptr, // offset_v std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed std::shared_ptr>; // dropout_offset @@ -485,29 +538,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr q, k, v, o, dO, stats, attn_scale; std::shared_ptr bias, dBias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v, offset_o; + std::shared_ptr offset_q, offset_k, offset_v, offset_o, + offset_stats; std::shared_ptr dropout_seed, dropout_offset; - offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -522,6 +556,26 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -569,11 +623,26 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); } - stats = 
mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + if (is_ragged && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, 1, h, 1}) + .set_data_type(fe::DataType_t::FLOAT) + .set_ragged_offset(offset_stats)); + } else { + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") @@ -589,6 +658,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + if (is_ragged && cudnn_runtime_version >= 90600) { + sdpa_backward_options.set_max_total_seq_len_q(s_q); + } + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_backward_options.set_sliding_window_length(window_size_left + 1); } @@ -682,8 +755,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + ? std::make_tuple(offset_stats) + : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -693,15 +769,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, offset_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, + offset_qkvo_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_k, offset_v, offset_o, dropout_seed, dropout_offset] = + offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -709,10 +786,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( // We do this by adding padding at the end of each separate allocation. auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size()); const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); - const size_t actual_seqlen_workspace_size = 2 * num_bytes_per_seqlen; + const size_t actual_seqlen_workspace_size = is_padding ? 
2 * num_bytes_per_seqlen : 0; const size_t num_bytes_per_ragged_offset = alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); - const size_t seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + size_t seqlen_offsets_workspace_size = 0; + if (is_ragged) { + if (cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + } else { + seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + } + } if (workspace == nullptr) { *workspace_size = plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size; @@ -752,7 +836,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + num_bytes_per_seqlen; cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlensQ), + actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); variant_pack[seq_q] = devActualSeqlenQ; @@ -767,15 +851,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + void *devOffsetsS = nullptr; + if (cudnn_runtime_version >= 90600) { + devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), + layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, - devOffsetsV, devOffsetsO); + devOffsetsV, devOffsetsO, devOffsetsS); variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; variant_pack[offset_o] = devOffsetsO; + if (cudnn_runtime_version >= 90600) { + variant_pack[offset_stats] = devOffsetsS; + } } if (is_dropout) { @@ -792,10 +883,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -803,6 +894,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( const auto QKV_type = input_QKV->data.dtype; void *devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); size_t stride = 0; if 
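
The seqlen-conversion launches above now receive actual_b (the real batch size, matching the b + 1 entry cu_seqlens buffers) alongside b (the bucketed batch the graph was built for). The kernels themselves are not in this diff; a plausible reading, stated here only as an assumption, is that entries beyond actual_b are treated as empty sequences so the padded graph sees zero-length work. A host-side analogue of that assumption:

```cpp
// Hypothetical host-side analogue; not taken from this diff.
#include <cstdint>
#include <vector>

std::vector<int32_t> pad_actual_seqlens(const std::vector<int32_t> &cu_seqlens,  // actual_b + 1
                                        int64_t actual_b, int64_t bucketed_b) {
  std::vector<int32_t> seqlens(bucketed_b, 0);  // bucket-padded entries stay empty
  for (int64_t i = 0; i < actual_b; ++i) {
    seqlens[i] = cu_seqlens[i + 1] - cu_seqlens[i];
  }
  return seqlens;
}
```
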
(layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { stride = typeToSize(QKV_type) * num_attn_heads * head_dim; @@ -821,17 +913,30 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens = get_max_tokens(num_tokens); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -845,7 +950,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -875,12 +984,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, - bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, - devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, - handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, + max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, + devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -898,10 +1007,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( } void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - const Tensor 
*input_QKV, const Tensor *input_O, const Tensor *input_dO, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -909,7 +1018,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( const auto QKV_type = input_QKV->data.dtype; void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { @@ -934,6 +1042,14 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens = get_max_tokens(num_tokens); + } + void *devPtrdQKV = output_dQKV->data.dptr; void *devPtrdQ = devPtrdQKV; void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); @@ -952,12 +1068,13 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, - bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, + max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, + bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, + devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, + devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -975,19 +1092,21 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( } void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, 
cudnnHandle_t handle) { + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; void *devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; @@ -1005,6 +1124,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -1013,12 +1133,26 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1032,7 +1166,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1063,11 +1201,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, 
devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1086,12 +1224,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1122,6 +1261,16 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + void *devPtrdQ = output_dQ->data.dptr; void *devPtrdKV = output_dKV->data.dptr; void *devPtrdK = devPtrdKV; @@ -1143,12 +1292,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, + devPtrK, 
devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1167,8 +1316,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -1177,6 +1327,7 @@ void fused_attn_arbitrary_seqlen_fwd( using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); void *devPtrQ = input_Q->data.dptr; void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; @@ -1196,12 +1347,26 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1215,7 +1380,11 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1246,11 +1415,11 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, 
head_dim_qk, head_dim_v, - bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1269,13 +1438,13 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1296,6 +1465,16 @@ void fused_attn_arbitrary_seqlen_bwd( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + void *devPtrdQ = output_dQ->data.dptr; void *devPtrdK = output_dK->data.dptr; void *devPtrdV = output_dV->data.dptr; @@ -1315,12 +1494,12 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, 
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, + devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 4b523cca1a..3a1216f891 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -19,47 +19,50 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, 
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -68,13 +71,13 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const 
Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 88c1490c01..d3422de481 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -746,7 +746,7 @@ void fused_attn_max_512_fwd_impl( void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlenQ), + b, b, static_cast(devPtrCuSeqlenQ), static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -1169,7 +1169,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlenQ), + b, b, static_cast(devPtrCuSeqlenQ), static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); NVTE_CHECK_CUDA(cudaGetLastError()); diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 7f76dcad77..ca00218d9a 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include "../common.h" #include "transformer_engine/fused_attn.h" @@ -353,66 +354,75 @@ __global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t * } // convert cu_seqlens to actual_seqlens -__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens, +__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, + int32_t const *const q_cu_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, int32_t *kv_seqlens) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < b) { + if (tid < actual_b) { q_seqlens[tid] = q_cu_seqlens[tid + 1] - q_cu_seqlens[tid]; kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid]; + } else if (tid < max_b) { + q_seqlens[tid] = 0; + kv_seqlens[tid] = 0; } } // convert cu_seqlens_padded to offsets template -__device__ void cu_seqlens_padded_to_offsets_impl(NVTE_QKV_Layout_Group layout_group, 
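The reworked `cu_seqlens_to_actual_seqlens` kernel above takes both the real batch size (`actual_b`) and the padded one (`max_b`) and zero-fills the sequence lengths of the padding entries, so a graph built for a bucketed batch size still sees well-defined (zero-length) sequences. A minimal NumPy sketch of what the kernel computes, with made-up sample values:

```python
import numpy as np

# Host-side sketch of cu_seqlens_to_actual_seqlens(actual_b, max_b, ...):
# per-sequence lengths are the first differences of the cumulative array,
# and entries in [actual_b, max_b) are zeroed for the padded batch slots.
cu_seqlens_q = np.array([0, 5, 12, 20], dtype=np.int32)   # actual_b = 3 real sequences
cu_seqlens_kv = np.array([0, 4, 10, 20], dtype=np.int32)
actual_b, max_b = 3, 8                                    # max_b: padded/bucketed batch size

q_seqlens = np.zeros(max_b, dtype=np.int32)
kv_seqlens = np.zeros(max_b, dtype=np.int32)
q_seqlens[:actual_b] = np.diff(cu_seqlens_q)              # [5, 7, 8, 0, 0, 0, 0, 0]
kv_seqlens[:actual_b] = np.diff(cu_seqlens_kv)            # [4, 6, 10, 0, 0, 0, 0, 0]
```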
int64_t b, - int64_t h, int64_t hg, int64_t d_qk, int64_t d_v, - const int32_t *cu_seqlens_q_padded, - const int32_t *cu_seqlens_kv_padded, - OFFSETS_T *offsets_q, OFFSETS_T *offsets_k, - OFFSETS_T *offsets_v, OFFSETS_T *offsets_o) { +__device__ void cu_seqlens_padded_to_offsets_impl( + NVTE_QKV_Layout_Group layout_group, int64_t actual_b, int64_t max_b, int64_t h, int64_t hg, + int64_t d_qk, int64_t d_v, const int32_t *cu_seqlens_q_padded, + const int32_t *cu_seqlens_kv_padded, OFFSETS_T *offsets_q, OFFSETS_T *offsets_k, + OFFSETS_T *offsets_v, OFFSETS_T *offsets_o, OFFSETS_T *offsets_s) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < b + 1) { - offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid]; + auto cu_seqlens_id = min(tid, actual_b); + if (tid <= max_b) { + offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; + if (offsets_s != nullptr) { + offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id]; + } switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; - offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid]; - offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; break; case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_H3D: - offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid]; - offsets_k[tid] = offsets_q[tid]; - offsets_v[tid] = offsets_q[tid]; + offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_k[tid] = offsets_q[cu_seqlens_id]; + offsets_v[tid] = offsets_q[cu_seqlens_id]; break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; - offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid]; - offsets_v[tid] = offsets_k[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = offsets_k[cu_seqlens_id]; break; } } } -__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t b, - int64_t h, int64_t hg, int64_t d_qk, int64_t d_v, - const int32_t *cu_seqlens_q_padded, +__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b, + int64_t max_b, int64_t h, int64_t hg, int64_t d_qk, + int64_t d_v, const int32_t *cu_seqlens_q_padded, const int32_t *cu_seqlens_kv_padded, DType offset_dtype, void *offsets_q, void *offsets_k, - void *offsets_v, void *offsets_o) { + void *offsets_v, void *offsets_o, void *offsets_s) { if (offset_dtype == DType::kInt32) { cu_seqlens_padded_to_offsets_impl( - layout_group, b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded, + layout_group, actual_b, max_b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded, reinterpret_cast(offsets_q), reinterpret_cast(offsets_k), - reinterpret_cast(offsets_v), reinterpret_cast(offsets_o)); + reinterpret_cast(offsets_v), reinterpret_cast(offsets_o), + reinterpret_cast(offsets_s)); } else { assert(offset_dtype == DType::kInt64 && "expect int64"); cu_seqlens_padded_to_offsets_impl( - layout_group, b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded, + layout_group, actual_b, max_b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded, reinterpret_cast(offsets_q), 
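The updated `cu_seqlens_padded_to_offsets` path clamps its lookup index to `actual_b`, so the extra entries up to `max_b` repeat the last real cumulative value instead of reading past the end of `cu_seqlens_*_padded`, and it now also emits `offsets_s` for the ragged softmax-stats tensor. A NumPy sketch of the separate-Q/K/V (`NVTE_HD_HD_HD`) branch, with illustrative sizes (not values from the patch):

```python
import numpy as np

h, hg, d_qk, d_v = 16, 4, 64, 64                      # heads, GQA groups, head dims (illustrative)
actual_b, max_b = 2, 4                                # real vs. padded batch size
cu_q_padded = np.array([0, 8, 16], dtype=np.int64)    # actual_b + 1 cumulative entries
cu_kv_padded = np.array([0, 8, 16], dtype=np.int64)

# tid = 0..max_b reads cu_seqlens_*[min(tid, actual_b)], mirroring the kernel's clamp.
idx = np.minimum(np.arange(max_b + 1), actual_b)      # [0, 1, 2, 2, 2]
offsets_q = h * d_qk * cu_q_padded[idx]
offsets_k = hg * d_qk * cu_kv_padded[idx]
offsets_v = hg * d_v * cu_kv_padded[idx]
offsets_o = h * d_v * cu_q_padded[idx]
offsets_s = h * cu_q_padded[idx]                      # new: offsets into the softmax stats
```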
reinterpret_cast(offsets_k), - reinterpret_cast(offsets_v), reinterpret_cast(offsets_o)); + reinterpret_cast(offsets_v), reinterpret_cast(offsets_o), + reinterpret_cast(offsets_s)); } } @@ -450,6 +460,40 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at return DType::kInt32; } +// quantize batch size +size_t get_max_batch_size(size_t batch_size) { + size_t max_b = batch_size; + size_t log2_b = ceil(log2(batch_size)); + // batch size is expected to be 10s-100s + // b = 1, ..., 32 -> max_b = 32 + // b = 33, ..., 512 -> max_b = next power of 2 + // otherwise -> max_b = b + if (log2_b <= 5) { + max_b = 32; + } else if (log2_b <= 9) { + max_b = pow(2, log2_b); + } + return max_b; +} + +// quantize token count +size_t get_max_tokens(size_t num_tokens) { + // token count is expected to be 1k's-100k's + // t = 0, ..., 1024 -> max_t = 1024 + // t = 1025, ..., 32k -> max_t = next power of 2 + // t = 32k+1, ... -> max_t = increment by 32k + size_t log2_t = ceil(log2(num_tokens)); + size_t max_t = 0; + if (log2_t <= 10) { + max_t = 1024; + } else if (log2_t <= 15) { + max_t = pow(2, log2_t); + } else { + max_t = (num_tokens + 32767) / 32768 * 32768; + } + return max_t; +} + } // namespace fused_attn // get cuDNN data type diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index bea7ed05dd..c060c4907d 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -122,21 +122,24 @@ __global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t * int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, int32_t *o_ragged_offset); -__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens, +__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, + int32_t const *const q_cu_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, int32_t *kv_seqlens); -__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t b, - int64_t h, int64_t hg, int64_t d_qk, int64_t d_v, - const int32_t *cu_seqlens_q_padded, +__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b, + int64_t max_b, int64_t h, int64_t hg, int64_t d_qk, + int64_t d_v, const int32_t *cu_seqlens_q_padded, const int32_t *cu_seqlens_kv_padded, DType offset_dtype, void *offsets_q, void *offsets_k, - void *offsets_v, void *offsets_o); + void *offsets_v, void *offsets_o, void *offsets_s); DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads, int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv, int64_t head_dim_qk, int64_t head_dim_v); +size_t get_max_batch_size(size_t batch_size); +size_t get_max_tokens(size_t num_tokens); } // namespace fused_attn cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h new file mode 100644 index 0000000000..17ecca5ff0 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -0,0 +1,201 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
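`get_max_batch_size` and `get_max_tokens` above bucket the runtime batch and token counts so that variable-shape (THD) workloads reuse a small set of padded graph sizes; on cuDNN 9.6+ the bucketed token count also sizes the `{max_tokens_q, num_heads, 1}` softmax-stats tensor seen earlier. A Python rendering of the same rounding rules (a sketch for illustration, not an exported API), with a few sample values:

```python
import math

def max_batch_size(b: int) -> int:
    # b <= 32 -> 32; 33..512 -> next power of two; larger -> unchanged
    log2_b = math.ceil(math.log2(b))
    if log2_b <= 5:
        return 32
    if log2_b <= 9:
        return 1 << log2_b
    return b

def max_tokens(t: int) -> int:
    # t <= 1024 -> 1024; up to 32768 -> next power of two; larger -> multiple of 32768
    log2_t = math.ceil(math.log2(t))
    if log2_t <= 10:
        return 1024
    if log2_t <= 15:
        return 1 << log2_t
    return (t + 32767) // 32768 * 32768

assert [max_batch_size(b) for b in (8, 100, 600)] == [32, 128, 600]
assert [max_tokens(t) for t in (900, 5000, 50000)] == [1024, 8192, 65536]
```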
+ ************************************************************************/
+
+#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
+#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
+
+#include
+#include
+#include
+
+#include
+
+#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
+
+#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
+
+namespace transformer_engine {
+
+/* \brief Check if Userbuffers bootstraps with direct calls to MPI collectives.
+ * This can be turned on by building Transformer Engine with the `NVTE_UB_WITH_MPI=1` option.
+ *
+ * \return True if Userbuffers is built with MPI
+ */
+bool ubuf_built_with_mpi();
+
+enum class CommOverlapType { RS = 0, AG = 1 };
+
+enum class CommOverlapAlgo {
+  BULK_OVERLAP_AG = 0,
+  BULK_OVERLAP_RS = 1,
+  SPLIT_PIPELINED_AG_P2P = 2,
+  SPLIT_PIPELINED_RS = 3,
+  SPLIT_PIPELINED_RS_P2P = 4,
+  ATOMIC_GEMM_RS = 5,
+  ATOMIC_GEMM_AG_P2P = 6,
+  ATOMIC_GEMM_RS_P2P = 7
+};
+
+class CommOverlapCore {
+ protected:
+  static inline communicator *_ub_comm{nullptr};
+  static inline bool _comm_created{false};
+
+  int _rank;
+  int _tp_id;
+  int _tp_size;
+  int _num_splits;
+  int _math_sms;
+  int _num_comm_sm;
+  int _cga_size;
+  int _use_ce;
+  int _ub_reg;
+  bool _atomic_gemm{false};
+  bool _is_p2p{false};
+
+  TensorWrapper _ubuf;
+  TensorWrapper _counter;
+  float *_ubuf_scale_inv;
+  bool _ubuf_scale_inv_initialized{false};
+
+  std::vector _stream_compute;
+  cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
+
+ public:
+  CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
+                  int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
+                  int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm,
+                  bool set_sm_margin, bool use_ce, bool atomic_gemm);
+
+  virtual ~CommOverlapCore();
+
+  void set_ubuf_scale_inv(float *scale_inv) {
+    _ubuf_scale_inv = scale_inv;
+    _ubuf_scale_inv_initialized = true;
+  }
+
+  bool is_atomic_gemm() { return _atomic_gemm; }
+
+  bool is_p2p_overlap() { return _is_p2p; }
+
+  bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }
+};  // CommOverlapCore
+
+class CommOverlapBase : public CommOverlapCore {
+ protected:
+  int _rs_kernel_type;
+  cudaStream_t _stream_comm;
+  cudaEvent_t _start_d2dcopy;
+
+ public:
+  CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank,
+                  int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
+                  ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
+                  int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
+                  int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);
+
+  virtual ~CommOverlapBase();
+
+  /*
+  ** Bulk GEMM + COMM
+  ** This function assumes the communication input is pre-copied to _ubuf
+  */
+  void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D,
+                    TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
+                    bool grad, bool accumulate, bool use_split_accumulator,
+                    CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main);
+
+  /*
+  ** Split FPROP GEMM + ReduceScatter
+  */
+  void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
+                              TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
+                              TensorWrapper &workspace, bool grad, bool accumulate,
+                              bool use_split_accumulator, bool gemm_overlap,
+                              TensorWrapper &rs_output, cudaStream_t stream_main);
+
+  /*
+  ** Split FPROP GEMM
+ ReduceScatter + */ + void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main); +}; // CommOverlapBase + +class CommOverlapP2PBase : public CommOverlapCore { + protected: + bool _is_reduce_scatter{false}; + bool _use_multiatomic_ag{false}; + + int _next_rank; + int _prev_rank; + int _rank_round_tp; + int _aggregate; + int _num_ubuf_chunks; + int _self_chunk_id; + + std::vector _ubufs; + + cudaStream_t _stream_send; + cudaStream_t _stream_recv; + cudaEvent_t _stop_send, _stop_recv; + + public: + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false, + bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + + virtual ~CommOverlapP2PBase(); + + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. + */ + void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main); + + /* + ** Split AllGather + GEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. 
+ */ + void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main); +}; // CommOverlapP2PBase + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 191fc40ead..d302518235 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -78,13 +78,13 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType */ void nvte_destroy_tensor(NVTETensor tensor); -/*! \brief Get a tensor's data type. +/*! \brief Get a raw pointer to the tensor's data. * * \param[in] tensor Tensor. * - * \return A data type of the input tensor. + * \return A raw pointer to tensor's data. */ -NVTEDType nvte_tensor_type(const NVTETensor tensor); +void *nvte_tensor_data(const NVTETensor tensor); /*! \brief Get a tensor's data shape. * @@ -94,13 +94,46 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor); */ NVTEShape nvte_tensor_shape(const NVTETensor tensor); -/*! \brief Get a raw pointer to the tensor's data. +/*! \brief Get a tensor's number of dimensions. * * \param[in] tensor Tensor. * - * \return A raw pointer to tensor's data. + * \return Number of tensor dimensions. */ -void *nvte_tensor_data(const NVTETensor tensor); +size_t nvte_tensor_ndims(const NVTETensor tensor); + +/*! \brief Get the size of a specific tensor dimension. + * + * \param[in] tensor Tensor. + * \param[in] size_t Dimension index. + * + * \return Size of the tensor at the specified dimension. + */ +size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim); + +/*! \brief Get a tensor's total number of elements. + * + * \param[in] tensor Tensor. + * + * \return Number of elements in the tensor. + */ +size_t nvte_tensor_numel(const NVTETensor tensor); + +/*! \brief Get the byte size for the tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return Byte size of the tensor's data type. + */ +size_t nvte_tensor_element_size(const NVTETensor tensor); + +/*! \brief Get a tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return A data type of the input tensor. + */ +NVTEDType nvte_tensor_type(const NVTETensor tensor); /*! \brief Get a pointer to the tensor's amax data. * @@ -265,6 +298,56 @@ class TensorWrapper { return nvte_tensor_shape(tensor_); } + /*! \brief Get the size of this TensorWrapper in the given dimension. 
+   *
+   * \param[in] size_t Dimension index.
+   *
+   * \return Size of this TensorWrapper in the given dimension.
+   */
+  size_t size(const size_t dim) const {
+    if (tensor_ == nullptr) return 0;
+    return nvte_tensor_size(tensor_, dim);
+  }
+
+  /*! \brief Get the number of dimensions for this TensorWrapper.
+   *
+   * \return Number of dimensions for this TensorWrapper.
+   */
+  size_t ndim() const noexcept {
+    if (tensor_ == nullptr) return 0;
+    return nvte_tensor_ndims(tensor_);
+  }
+
+  /*! \brief Get the number of allocated elements in the tensor. This will return 0 for tensors
+   *         with nullptr data even if the TensorWrapper has a non-zero shape.
+   *
+   * \return Number of elements in the tensor.
+   */
+  size_t numel() const noexcept {
+    if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
+    return nvte_tensor_numel(tensor_);
+  }
+
+  /*! \brief Get the tensor's element size in bytes.
+   *
+   * \return Element size in bytes.
+   */
+  size_t element_size() const noexcept {
+    if (tensor_ == nullptr) return 0;
+    return nvte_tensor_element_size(tensor_);
+  }
+
+  /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr
+   *         data even if the TensorWrapper has a non-zero shape and valid dtype.
+   *
+   * \return Total tensor size in bytes.
+   */
+  size_t bytes() const noexcept {
+    if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
+    return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_);
+  }
+
   /*! \brief Get the data type of this TensorWrapper.
    *
    * \return Data type of this TensorWrapper.
@@ -317,6 +400,6 @@ class TensorWrapper {

 }  // namespace transformer_engine

-#endif
+#endif  // __cplusplus

 #endif  // TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 5cfab2f8cf..1a3b49f9fa 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -93,6 +93,31 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
   return ret;
 }

+size_t nvte_tensor_ndims(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast(tensor);
+  return t.data.shape.size();
+}
+
+size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
+  const auto &t = *reinterpret_cast(tensor);
+  NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim);
+  return t.data.shape[dim];
+}
+
+size_t nvte_tensor_numel(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast(tensor);
+  size_t numel = 1;
+  for (auto size : t.data.shape) {
+    numel *= size;
+  }
+  return numel;
+}
+
+size_t nvte_tensor_element_size(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast(tensor);
+  return transformer_engine::typeToSize(t.data.dtype);
+}
+
 void *nvte_tensor_data(const NVTETensor tensor) {
   const auto &t = *reinterpret_cast(tensor);
   return t.data.dptr;
diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp
index 5728ef557a..8d2e852988 100644
--- a/transformer_engine/common/util/cuda_runtime.cpp
+++ b/transformer_engine/common/util/cuda_runtime.cpp
@@ -12,6 +12,7 @@
 #include "../common.h"
 #include "../util/cuda_driver.h"
 #include "../util/system.h"
+#include "common/util/cuda_runtime.h"

 namespace transformer_engine {

@@ -80,6 +81,31 @@ int sm_count(int device_id) {
   return cache[device_id];
 }

+bool supports_multicast(int device_id) {
+#if CUDART_VERSION >= 12010
+  // NOTE: This needs to be guarded at compile time because the
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions. + static std::vector cache(num_devices(), false); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); + auto init = [&]() { + CUdevice cudev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id); + int result; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); + cache[device_id] = static_cast(result); + }; + std::call_once(flags[device_id], init); + return cache[device_id]; +#else + return false; +#endif +} + const std::string &include_directory(bool required) { static std::string path; diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index b6b4c41610..ea1ba84772 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -38,6 +38,14 @@ int sm_arch(int device_id = -1); */ int sm_count(int device_id = -1); +/* \brief CUDA Multicast support status for device + * + * \param[in] device_id CUDA device (default is current device) + * + * \return CUDA multicast support flag + */ +bool supports_multicast(int device_id = -1); + /* \brief Path to CUDA Toolkit headers * * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h new file mode 100644 index 0000000000..432ac815ec --- /dev/null +++ b/transformer_engine/common/util/pybind_helper.h @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ + +#include +#include +#include +#include + +#include "cuda_runtime.h" + +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType") \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type") \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type") \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Layout") \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "CommOverlapType") \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo") \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ 
+ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); + +#endif diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 8438fa27ce..b3b11bb9dd 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -242,73 +242,16 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen -def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool): - match tensor_format: - case QKVFormat.SBHD: - seq_dim = 0 - case QKVFormat.BSHD: - seq_dim = 1 - case _: - raise ValueError(f"{tensor_format=} is not supported for causal load balancing.") - - if cp_size == 1: - return tensor - - if cp_size % 2 != 0: - raise ValueError(f"{cp_size=} must be a multiple of 2.") - - # Need to ensure we have 2 pairs to swap for balancing between cp ranks - if tensor.shape[seq_dim] % (cp_size * 2) != 0: - raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") - - # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] - # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] - ori_tensor_shape = tensor.shape - tensor = tensor.reshape( - ( - *ori_tensor_shape[:seq_dim], - 2 * cp_size, - ori_tensor_shape[seq_dim] // (2 * cp_size), - *ori_tensor_shape[seq_dim + 1 :], - ) - ) - - parts = [] - if not inverse: - for cp_rank in range(cp_size): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) - else: - for cp_rank in range(cp_size // 2): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - base = 4 * cp_rank - index = jnp.array([base, base + 2]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) - for cp_rank in range(cp_size // 2): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - base = 2 * cp_size - 1 - 4 * cp_rank - index = jnp.array([base, base - 2]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) - - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] - combined = jnp.stack(parts, axis=seq_dim) - - return combined.reshape(ori_tensor_shape) - - def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): """Reorders a tensor for load balancing the compute of causal attention.""" - return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, False) + seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 + return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False) def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): """Inverse operation of `reorder_causal_load_balancing`.""" - return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, True) + 
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 + return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True) def fused_attn( diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 47483c67ea..44b396ad55 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -383,37 +383,43 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum): assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - hidden_size = ir_x_shape[-1] - batch_shape = ir_x_shape[:-2] - batch_size = reduce(operator.mul, batch_shape) - out_shape = batch_shape + [hidden_size] - out_types = [ - ir.RankedTensorType.get(out_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_common_descriptor( - (batch_size, hidden_size), - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - act_enum, - ) + if is_ffi_enabled(): + name = "te_act_lu_fp8_ffi" + out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})( + ctx, x, amax, scale, scale_inv, act_enum=act_enum + ) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape - out = custom_caller( - ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} - ) + hidden_size = ir_x_shape[-1] + batch_shape = ir_x_shape[:-2] + batch_size = reduce(operator.mul, batch_shape) + out_shape = batch_shape + [hidden_size] + out_types = [ + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_common_descriptor( + (batch_size, hidden_size), + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + act_enum, + ) + + out = custom_caller( + ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} + ) return out diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 54a5327f08..b0ea51b8b0 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -15,6 +15,7 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import ( @@ -33,6 
+34,7 @@ te_dtype_to_jax_dtype, get_padded_spec, get_cudnn_version, + is_ffi_enabled, ) from ..sharding import ( global_mesh_resource, @@ -275,7 +277,16 @@ def abstract( softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq) + # cuDNN 9.6 reduces the required softmax shape + if get_cudnn_version() >= (9, 6, 0): + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) + else: + softmax_shape = ( + *batch_shape, + attn_heads, + q_max_seqlen, + config.max_segments_per_seq, + ) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") @@ -352,14 +363,6 @@ def lowering( """ Fused attention fwd lowering rules """ - operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( @@ -376,31 +379,82 @@ def lowering( wkspace_aval = ctx.avals_out[-1] - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, - bias_batch, - q_max_seqlen, - kv_max_seqlen, - attn_heads, - num_gqa_groups, - bias_heads, - head_dim, - config.max_segments_per_seq, - wkspace_aval.size, - config.scaling_factor, - config.dropout_probability, - config.attn_bias_type, - config.attn_mask_type, - config.qkv_layout, - jax_dtype_to_te_dtype(q_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - config.is_training, - not FusedAttnHelper.is_non_deterministic_allowed(), - config.window_size[0], - config.window_size[1], - ) + if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI")): + name = "te_fused_attn_forward_ffi" + out = ffi.ffi_lowering(name)( + ctx, + q, + k, + v, + bias, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + seed, + input_batch=input_batch, + bias_batch=bias_batch, + q_max_seqlen=q_max_seqlen, + kv_max_seqlen=kv_max_seqlen, + attn_heads=attn_heads, + num_gqa_groups=num_gqa_groups, + bias_heads=bias_heads, + head_dim=head_dim, + max_segments_per_seq=config.max_segments_per_seq, + scaling_factor=float(config.scaling_factor), + dropout_probability=float(config.dropout_probability), + bias_type=int(config.attn_bias_type), + mask_type=int(config.attn_mask_type), + qkv_layout=int(config.qkv_layout), + is_training=config.is_training, + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + window_size_left=config.window_size[0], + window_size_right=config.window_size[1], + ) + else: + operands = [ + q, + k, + v, + bias, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + seed, + ] + operand_shapes = map(lambda x: x.type.shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + head_dim, + config.max_segments_per_seq, + wkspace_aval.size, + config.scaling_factor, + 
config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, + jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + config.is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), + config.window_size[0], + config.window_size[1], + ) - out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) return out @@ -911,6 +965,58 @@ def sharded_impl( register_primitive(FusedAttnBwdPrimitive) +def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): + """Reorders a tensor for load balancing the compute of causal attention.""" + if cp_size == 1: + return tensor + + if cp_size % 2 != 0: + raise ValueError(f"{cp_size=} must be a multiple of 2.") + + # Need to ensure we have 2 pairs to swap for balancing between cp ranks + if tensor.shape[seq_dim] % (cp_size * 2) != 0: + raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + + # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] + # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] + ori_tensor_shape = tensor.shape + tensor = tensor.reshape( + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ) + ) + + parts = [] + if not to_contiguous: + for cp_rank in range(cp_size): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + else: + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 4 * cp_rank + index = jnp.array([base, base + 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 2 * cp_size - 1 - 4 * cp_rank + index = jnp.array([base, base - 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] + combined = jnp.stack(parts, axis=seq_dim) + + return combined.reshape(ori_tensor_shape) + + @dataclass(frozen=True) class _FusedAttnCPWithAllGatherHelper: """Helper class to assist with running the all-gather strategy for CP attention.""" @@ -954,13 +1060,32 @@ def get_adjusted_mask(self): return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type + def get_step_config(self) -> _FusedAttnConfig: + """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + return _FusedAttnConfig( + attn_bias_type=self.config.attn_bias_type, + attn_mask_type=self.get_adjusted_mask(), + qkv_layout=self.config.qkv_layout, + scaling_factor=self.config.scaling_factor, + dropout_probability=self.config.dropout_probability, + is_training=self.config.is_training, + max_segments_per_seq=self.config.max_segments_per_seq, + window_size=self.config.window_size, + context_parallel_load_balanced=self.config.context_parallel_load_balanced, + cp_axis=self.config.cp_axis, + ) + def all_gather_kv(self, k, v): """Performs a all-gather of k and 
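`reorder_causal_load_balancing`, now living next to the fused-attention primitives, pairs each context-parallel rank with one early chunk and its mirrored late chunk so causal-attention work is evenly distributed; `to_contiguous=True` applies the inverse permutation (used by `all_gather_kv`/`reduce_scatter_dkv` below). A self-contained NumPy sketch of the chunk permutation and its round trip for `cp_size=2`, without calling the TE helper itself:

```python
import numpy as np

cp_size = 2                       # context-parallel size (must be even)
seq_len = 8                       # must be divisible by 2 * cp_size
chunk = seq_len // (2 * cp_size)
x = np.arange(seq_len)            # stand-in for the sequence axis of a [B, S, H, D] tensor

# Load-balanced chunk order: rank r holds chunks [r, 2*cp_size - 1 - r].
order = [c for r in range(cp_size) for c in (r, 2 * cp_size - 1 - r)]   # [0, 3, 1, 2]
perm = np.concatenate([np.arange(c * chunk, (c + 1) * chunk) for c in order])

balanced = x[perm]                          # [0, 1, 6, 7, 2, 3, 4, 5]
restored = balanced[np.argsort(perm)]       # inverse reorder (to_contiguous=True)
assert (restored == x).all()
```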
v over context parallel ranks.""" def ag(x): - return lax_paral_op( + x = lax_paral_op( x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True ) + if self.config.context_parallel_load_balanced: + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True) + return x match self.config.qkv_layout: case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: @@ -974,6 +1099,10 @@ def reduce_scatter_dkv(self, dk, dv): """Performs a reduce-scatter of dk and dv over context parallel ranks.""" def rs(x): + if self.config.context_parallel_load_balanced: + cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) + x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False) + return lax_paral_op( x, lax.psum_scatter, @@ -1078,7 +1207,6 @@ def partition(config, mesh, arg_infos, result_infos): out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): - cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) @@ -1120,7 +1248,7 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): q_seq_offsets, k_seq_offsets, seed, - config=config, + config=helper.get_step_config(), ) results.append((output, softmax_aux, rng_state)) @@ -1237,7 +1365,7 @@ def _cross_attn_bwd( kv_seqlen_for_step, q_seq_offsets, k_seq_offsets, - config=config, + config=helper.get_step_config(), ) # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 58b8db4c88..d3df614ac9 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -162,7 +162,7 @@ def is_ffi_enabled(): """ Helper function checking if XLA Custom Call with FFI is enabled """ - is_supported = jax_version_meet_requirement("0.4.31") + is_supported = jax_version_meet_requirement("0.4.35") # New APIs with FFI are enabled by default is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1")) assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value" diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 48bf4d969a..062bbbf0fb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -9,6 +9,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import DType as TEDType @@ -20,6 +21,7 @@ check_valid_batch_dims, jax_dtype_to_te_dtype, jax_dtype_to_ir_dtype, + is_ffi_enabled, ) from ..sharding import all_reduce_max_along_all_axes_except_PP @@ -84,30 +86,36 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - out_types = [ - ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), - 
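The JAX lowering changes here and below all follow the same pattern: emit the new XLA FFI custom call when the installed JAX is new enough and the environment allows it, otherwise fall back to the legacy opaque-descriptor custom call; the fused-attention forward additionally requires an explicit opt-in variable. A sketch of that gating logic (mirroring `is_ffi_enabled` with the bumped `0.4.35` requirement, not the library function itself):

```python
import os
import jax

def ffi_enabled(min_jax: str = "0.4.35") -> bool:
    # New-enough JAX and NVTE_JAX_WITH_FFI not set to 0 (FFI defaults to enabled).
    current = tuple(int(p) for p in jax.__version__.split(".")[:3])
    required = tuple(int(p) for p in min_jax.split("."))
    return current >= required and int(os.getenv("NVTE_JAX_WITH_FFI", "1")) == 1

# Quantize/transpose/activation lowerings branch on ffi_enabled() alone; the
# fused-attention forward also checks NVTE_JAX_FUSED_ATTN_WITH_FFI before using FFI.
use_ffi_for_fused_attn = ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI"))
```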
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_common_descriptor( - ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype) - ) + if is_ffi_enabled(): + name = "te_quantize_ffi" + out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})( + ctx, x, amax, scale, scale_inv + ) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + out_types = [ + ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_common_descriptor( + ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype) + ) - out = custom_caller( - CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} - ) + out = custom_caller( + CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} + ) return out diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index 963d7f09e8..94585bc3e7 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -102,32 +102,36 @@ def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary): jnp.float8_e5m2, ] - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype) - if static_axis_boundary >= 0: - for i in range(static_axis_boundary + 1): - assert ir_x_shape[i] == 1 + if is_ffi_enabled(): + name = "te_transpose_ffi" + out = ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype) + if static_axis_boundary >= 0: + for i in range(static_axis_boundary + 1): + assert ir_x_shape[i] == 1 - transposed_x_shape = multidim_transpose( - ir_x_shape, static_axis_boundary, transpose_axis_boundary - ) + transposed_x_shape = multidim_transpose( + ir_x_shape, static_axis_boundary, transpose_axis_boundary + ) - out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)] - operands = [x] - operand_shapes = [ir_x_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)] + operands = [x] + operand_shapes = [ir_x_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - te_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - contracted_x_shape = ( - reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), - reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), - ) - opaque = transformer_engine_jax.pack_common_descriptor( - contracted_x_shape, te_dtype, te_dtype - ) + te_dtype = jax_dtype_to_te_dtype(x_aval.dtype) + contracted_x_shape = ( + reduce(operator.mul, 
ir_x_shape[:transpose_axis_boundary]), + reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), + ) + opaque = transformer_engine_jax.pack_common_descriptor( + contracted_x_shape, te_dtype, te_dtype + ) - out = custom_caller(TransposePrimitive.name, args, opaque, False) + out = custom_caller(TransposePrimitive.name, args, opaque, False) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c233177e28..e20bcf89ee 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -151,6 +151,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler); + void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, @@ -172,6 +174,8 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler); + XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler); pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, @@ -195,6 +199,8 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler); + pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, bool zero_centered_gamma, @@ -202,6 +208,8 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler); + void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -212,6 +220,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler); + void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); // Softmax @@ -253,6 +263,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); + pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 2baba48acf..f0fd58bfc5 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -110,7 +110,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type outp auto *output = output_buf->untyped_data(); auto input_dims = input_buf.dimensions(); - auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, 
std::multiplies<>()); + auto m = product(input_dims, 0, input_dims.size() - 2); auto n = input_dims.back(); auto act_len = input_dims.end()[-2]; auto act_type = static_cast(act_enum); @@ -153,6 +153,51 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op act_enum, act_len); } +Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf, + Result_Type amax_out_buf, int64_t act_enum) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *output = output_buf->untyped_data(); + float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); + NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive."); + + if (!use_fp8(out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + + auto input_dims = input_buf.dimensions(); + auto m = product(input_dims, 0, input_dims.size() - 2); + auto n = input_dims.back(); + auto act_len = input_dims.end()[-2]; + auto act_type = static_cast(act_enum); + + ActLuImpl(input, m, n, in_dtype, out_dtype, scale, stream, scale_inv, amax_out, output, act_type, + act_len); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuFP8Handler, ActLuFP8FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret() // amax_out + .Attr("act_enum"), + FFI_CudaGraph_Traits); + void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; auto *act_input = buffers[1]; @@ -219,8 +264,7 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act auto *output = output_buf->untyped_data(); auto act_input_dims = act_input_buf.dimensions(); - auto m = - std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2, 1, std::multiplies<>()); + auto m = product(act_input_dims, 0, act_input_dims.size() - 2); auto n = act_input_dims.back(); auto act_len = act_input_dims.end()[-2]; diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 265ac27218..3679b46ee5 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -30,18 +30,13 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812 - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359 */ -void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, - const CustomCallFusedAttnDescriptor *desc, +void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch, + const size_t bias_batch, const size_t attn_heads, + const size_t bias_heads, const size_t q_max_seqlen, + const size_t kv_max_seqlen, DType dtype, NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend, void *softmax_buf, void *rng_state_buf = nullptr, void *bias_buf = nullptr) { - auto input_batch = desc->input_batch; - auto bias_batch = desc->bias_batch; - auto 
attn_heads = desc->attn_heads; - auto bias_heads = desc->bias_heads; - auto q_max_seqlen = desc->q_max_seqlen; - auto kv_max_seqlen = desc->kv_max_seqlen; - // all backends need softmax but expect different shapes/dtypes // start with the max512 sequence length softmax shape/dtype and correct later tensor_pack->size = 1; @@ -49,7 +44,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, softmax_aux->data.dptr = softmax_buf; softmax_aux->data.shape = std::vector{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen}; - softmax_aux->data.dtype = desc->dtype; + softmax_aux->data.dtype = dtype; // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { @@ -69,7 +64,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, bias_aux->data.dptr = bias_buf; bias_aux->data.shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - bias_aux->data.dtype = desc->dtype; + bias_aux->data.dtype = dtype; } } } @@ -82,22 +77,25 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()? */ -void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, - const CustomCallFusedAttnDescriptor *desc, +void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch, + const size_t bias_batch, const size_t attn_heads, + const size_t bias_heads, const size_t q_max_seqlen, + const size_t kv_max_seqlen, DType dtype, NVTE_Fused_Attn_Backend backend, void *softmax_buf, void *rng_state_buf, void *bias_buf) { // Backward calls put everything into the tensor pack for every backend // so we set dummy bias_type and backend choices here to follow the correct code path auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend, softmax_buf, - rng_state_buf, bias_buf); + PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads, + q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type, + dummy_backend, softmax_buf, rng_state_buf, bias_buf); // correct softmax shape for max512 sequence length kernel if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); - softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} - softmax_aux->data.dtype = desc->dtype; + softmax_aux->data.shape.at(3) = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} + softmax_aux->data.dtype = dtype; } } @@ -187,47 +185,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } -void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - auto qkv_layout = descriptor.qkv_layout; +static void FusedAttnForwardImpl( + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens, + void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output, + void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, + size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, + size_t bias_heads, size_t 
head_dim, size_t max_segments_per_seq, size_t wkspace_size, + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, + bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; - /* Input buffers from XLA */ - /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ - void *bias = buffers[3]; - void *q_cu_seqlens = buffers[4]; - void *kv_cu_seqlens = buffers[5]; - void *q_seq_offsets = is_ragged ? buffers[6] : nullptr; - void *k_seq_offsets = is_ragged ? buffers[7] : nullptr; - void *seed = buffers[8]; - - /* Output buffer from XLA */ - void *output = buffers[9]; - void *softmax_aux = buffers[10]; - void *rng_state = buffers[11]; - void *workspace = buffers[12]; - - /* Descriptor */ - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - auto dtype = descriptor.dtype; - auto is_training = descriptor.is_training; - auto max_segments_per_seq = descriptor.max_segments_per_seq; - auto window_size_left = descriptor.window_size_left; - auto window_size_right = descriptor.window_size_right; - /* Input tensors */ auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; @@ -250,8 +218,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); num_segments = runtime_num_segments_q; } - cudaMemsetAsync(output, 0, - input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream); + auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; + cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); } auto q_cu_seqlens_tensor = @@ -279,32 +247,30 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s /* Auxiliary tensors (to be propagated to the backward pass later) */ NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); - PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, - softmax_aux); + PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads, + bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type, + backend, softmax_aux); /* cuDNN workspace */ - auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, - descriptor.wkspace_dtype); + auto workspace_tensor = + TensorWrapper(workspace, std::vector{wkspace_size}, wkspace_dtype); - /* Call the underly NVTE API */ + /* Call the underlying NVTE API */ auto layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv = buffers[0]; auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, 
dtype); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - workspace_tensor.data(), stream); + auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); + nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, is_training, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv = buffers[1]; auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv_tensor = TensorWrapper(k, kv_shape, dtype); nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), @@ -312,14 +278,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k = buffers[1]; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v = buffers[2]; auto v_shape = k_shape; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, @@ -335,6 +298,109 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s nvte_tensor_pack_destroy(&aux_output_tensors); } +void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + const CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque(opaque, opaque_len); + auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD; + + /* Input buffers from XLA */ + /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ + void *bias = buffers[3]; + void *q_cu_seqlens = buffers[4]; + void *kv_cu_seqlens = buffers[5]; + void *q_seq_offsets = is_ragged ? buffers[6] : nullptr; + void *k_seq_offsets = is_ragged ? 
buffers[7] : nullptr; + void *seed = buffers[8]; + + /* Output buffer from XLA */ + void *output = buffers[9]; + void *softmax_aux = buffers[10]; + void *rng_state = buffers[11]; + void *workspace = buffers[12]; + + FusedAttnForwardImpl( + stream, buffers[0], buffers[1], buffers[2], bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, + k_seq_offsets, seed, output, softmax_aux, rng_state, workspace, descriptor.input_batch, + descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen, + descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, + descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor, + descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type, + descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training, + descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right); +} + +Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, + Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, + Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, + Buffer_Type seed_buf, Result_Type output_buf, + Result_Type softmax_aux_buf, Result_Type rng_state_buf, + Result_Type workspace_buf, Dictionary attrs) { + /* Descriptor data type conversion */ + size_t input_batch = get_attr_value(attrs, "input_batch"); + size_t bias_batch = get_attr_value(attrs, "bias_batch"); + size_t q_max_seqlen = get_attr_value(attrs, "q_max_seqlen"); + size_t kv_max_seqlen = get_attr_value(attrs, "kv_max_seqlen"); + size_t attn_heads = get_attr_value(attrs, "attn_heads"); + size_t num_gqa_groups = get_attr_value(attrs, "num_gqa_groups"); + size_t bias_heads = get_attr_value(attrs, "bias_heads"); + size_t head_dim = get_attr_value(attrs, "head_dim"); + size_t max_segments_per_seq = get_attr_value(attrs, "max_segments_per_seq"); + auto window_size_left = get_attr_value(attrs, "window_size_left"); + auto window_size_right = get_attr_value(attrs, "window_size_right"); + + float scaling_factor = get_attr_value(attrs, "scaling_factor"); + float dropout_probability = get_attr_value(attrs, "dropout_probability"); + + NVTE_Bias_Type bias_type = + static_cast(get_attr_value(attrs, "bias_type")); + NVTE_Mask_Type mask_type = + static_cast(get_attr_value(attrs, "mask_type")); + NVTE_QKV_Layout qkv_layout = + static_cast(get_attr_value(attrs, "qkv_layout")); + + bool is_training = get_attr_value(attrs, "is_training"); + bool deterministic = get_attr_value(attrs, "deterministic"); + + auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; + + size_t wkspace_size = product(workspace_buf->dimensions()); + DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type()); + DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); + + FusedAttnForwardImpl( + stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), + bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(), + is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? 
k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(), + output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), + workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, + attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, + scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, + is_training, deterministic, window_size_left, window_size_right); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // bias + .Arg() // q_cu_seqlens + .Arg() // kv_cu_seqlens + .Arg() // q_seq_offsets + .Arg() // k_seq_offsets + .Arg() // seed_buf + .Ret() // output + .Ret() // softmax_aux + .Ret() // rng_state + .Ret() // workspace + .Attrs(), + FFI_CudaGraph_Traits); + pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, @@ -523,8 +589,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); - PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, - rng_state, bias); + PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, + bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, + softmax_aux, rng_state, bias); /* cuDNN workspace */ auto wkspace_size = std::vector{descriptor.wkspace_size}; @@ -540,7 +607,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto dqkv = buffers[12]; auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dqkv, 0, product(qkv_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dqkv, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 @@ -562,8 +629,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto dkv = buffers[13]; auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dq, 0, product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dkv, 0, product(kv_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dkv, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -591,9 +658,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto dv = buffers[14]; auto dv_tensor = TensorWrapper(dv, v_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dq, 0, product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dk, 0, product(k_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dv, 0, product(v_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dk, 0, 
transformer_engine::product(k_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dv, 0, transformer_engine::product(v_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index 19fd50cbd1..8b627aad35 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -7,20 +7,27 @@ #include -#include "common/util/logging.h" - namespace transformer_engine { namespace jax { // For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186 DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { switch (type) { - case xla::ffi::DataType::F16: - return DType::kFloat16; + case xla::ffi::DataType::U8: + return DType::kByte; + break; + case xla::ffi::DataType::S32: + return DType::kInt32; + break; + case xla::ffi::DataType::S64: + return DType::kInt64; break; case xla::ffi::DataType::F32: return DType::kFloat32; break; + case xla::ffi::DataType::F16: + return DType::kFloat16; + break; case xla::ffi::DataType::BF16: return DType::kBFloat16; break; diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index 729e8e60e3..48b6f69ea8 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -9,6 +9,8 @@ #include +#include "common/util/logging.h" + namespace transformer_engine { namespace jax { @@ -17,10 +19,63 @@ using Result_Type = xla::ffi::Result; using Error_Type = xla::ffi::Error; using FFI = xla::ffi::Ffi; using FFI_Stream_Type = xla::ffi::PlatformStream; +using Dictionary = xla::ffi::Dictionary; constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible}; -DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type); +DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type); + Error_Type ffi_with_cuda_error_check(); +// source_location is not available in C++17, so we implement it ourselves +#if defined(__GNUC__) || defined(__clang__) +#define CURRENT_FILE __builtin_FILE() +#define CURRENT_LINE __builtin_LINE() +#define CURRENT_FUNCTION __builtin_FUNCTION() +#else +#define CURRENT_FILE __FILE__ +#define CURRENT_LINE __LINE__ +#define CURRENT_FUNCTION __func__ +#endif + +class source_location { + public: + static source_location current(const char* file = CURRENT_FILE, int line = CURRENT_LINE, + const char* function = CURRENT_FUNCTION) { + return source_location(file, line, function); + } + + constexpr const char* file_name() const { return file_; } + constexpr int line() const { return line_; } + constexpr const char* function_name() const { return function_; } + + private: + constexpr source_location(const char* file, int line, const char* function) + : file_(file), line_(line), function_(function) {} + + const char* file_; + int line_; + const char* function_; +}; + +template +T get_attr_value(Dictionary& attrs, std::string attr_name, + const source_location& loc = source_location::current()) { + auto attr = attrs.get(attr_name); + if (attr.has_error()) { + NVTE_ERROR("Failure in getting attribute value of '", attr_name, "'\n", + "Called from: ", loc.file_name(), ":", loc.line(), "\n", + "In function: ", loc.function_name(), "\n", + "Please ensure the attribute name and datatype match between 
C++ and Python APIs."); + } + return attr.value(); +} + +inline size_t product(const xla::ffi::Span& data, size_t start_idx = 0, + size_t end_idx = 0) { + end_idx = (end_idx == 0) ? data.size() : end_idx; + return std::accumulate(data.begin() + start_idx, data.begin() + end_idx, size_t(1), + std::multiplies()); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index fb40400e62..4c6fcb6394 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -237,6 +237,72 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque sm_margin, stream); } +Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, + Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf, + Buffer_Type scale_inv_buf, Result_Type output_buf, + Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type amax_out_buf, Result_Type wkspace_buf, + Result_Type barrier_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); + auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); + auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type()); + auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type()); + + auto *input = x_buf.untyped_data(); + auto *weight = gamma_buf.untyped_data(); + auto *bias = beta_buf.untyped_data(); + auto *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *scale = reinterpret_cast(scale_buf.untyped_data()); + auto *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + auto *output = output_buf->untyped_data(); + auto *mu = mu_buf->untyped_data(); + auto *rsigma = rsigma_buf->untyped_data(); + auto *amax_out = amax_out_buf->untyped_data(); + auto *workspace = wkspace_buf->untyped_data(); + auto *barrier = barrier_buf->untyped_data(); + NVTE_CHECK(amax_out == amax, + "amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive"); + + auto x_size = product(x_buf.dimensions()); + auto gamma_size = product(gamma_buf.dimensions()); + auto hidden_size = gamma_size; + auto batch_size = x_size / gamma_size; + + auto wkspace_size = product(wkspace_buf->dimensions()); + auto barrier_size = product(barrier_buf->dimensions()); + + float eps = static_cast(eps_); + int sm_margin = static_cast(sm_margin_); + auto out_dtype = DType::kFloat8E4M3; + + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, + eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, + wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, + sm_margin, stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // gamma + .Arg() // beta + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret() // mu + .Ret() // rsigma + .Ret() // amax_out + .Ret() // wkspace + .Ret() // barrier + .Attr("zero_centered_gamma") + .Attr("eps") + .Attr("sm_margin"), + FFI_CudaGraph_Traits); + void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; auto *weight = buffers[1]; @@ -310,6 +376,79 @@ void 
LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, dbeta_part_dtype, sm_margin, stream); } +Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, + Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf, + Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf, + Result_Type wkspace_buf, Result_Type barrier_buf, + Result_Type dgamma_part_buf, Result_Type dbeta_part_buf, + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); + auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); + auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type()); + auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type()); + auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype(dgamma_part_buf->element_type()); + auto dbeta_part_dtype = convert_ffi_datatype_to_te_dtype(dbeta_part_buf->element_type()); + + auto *ograd = dz_buf.untyped_data(); + auto *mu = mu_buf.untyped_data(); + auto *rsigma = rsigma_buf.untyped_data(); + auto *input = x_buf.untyped_data(); + auto *weight = gamma_buf.untyped_data(); + auto *xgrad = xgrad_buf->untyped_data(); + auto *wgrad = wgrad_buf->untyped_data(); + auto *dbeta = dbeta_buf->untyped_data(); + auto *workspace = wkspace_buf->untyped_data(); + auto *barrier = barrier_buf->untyped_data(); + auto *dgamma_part = dgamma_part_buf->untyped_data(); + auto *dbeta_part = dbeta_part_buf->untyped_data(); + + auto x_size = product(x_buf.dimensions()); + auto gamma_size = product(gamma_buf.dimensions()); + auto hidden_size = gamma_size; + auto batch_size = x_size / gamma_size; + + auto wkspace_size = product(wkspace_buf->dimensions()); + auto barrier_size = product(barrier_buf->dimensions()); + + auto dgamma_part_dims = dgamma_part_buf->dimensions(); + auto dbeta_part_dims = dbeta_part_buf->dimensions(); + std::vector dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end()); + std::vector dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end()); + Shape dgamma_part_shape, dbeta_part_shape; + dgamma_part_shape.from_vector(dgamma_parts_dims_vector); + dbeta_part_shape.from_vector(dbeta_parts_dims_vector); + + float eps = static_cast(eps_); + int sm_margin = static_cast(sm_margin_); + + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, + dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, + w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, + rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, + dbeta_part_dtype, sm_margin, stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // dz + .Arg() // x + .Arg() // mu + .Arg() // rsigma + .Arg() // gamma + .Ret() // xgrad + .Ret() // wgrad + .Ret() // dbeta + .Ret() // wkspace + .Ret() // barrier + .Ret() // dgamma_part + .Ret() // dbeta_part + .Attr("zero_centered_gamma") + .Attr("eps") + .Attr("sm_margin"), + FFI_CudaGraph_Traits); + void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; auto *weight = buffers[1]; diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 14f449a76b..f134229fc1 100644 --- 
a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -52,9 +52,15 @@ pybind11::dict Registrations() { dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler); dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); + dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); + dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); + dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); + dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); + dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler); return dict; } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index ba376c6238..5e33098eab 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -6,6 +6,7 @@ #include "extensions.h" #include "transformer_engine/cast.h" +#include "xla/ffi/api/c_api.h" namespace transformer_engine { namespace jax { @@ -27,6 +28,41 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); } +Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf, + Result_Type amax_out_buf) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *scale = reinterpret_cast(scale_buf.untyped_data()); + auto *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *output = output_buf->untyped_data(); + auto *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); + NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive."); + + auto input_dims = input_buf.dimensions(); + std::vector shape(input_dims.begin(), input_dims.end()); + auto input_tensor = TensorWrapper(input, shape, in_dtype); + auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv); + + nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(QuantizeHandler, QuantizeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret(), // amax_out + FFI_CudaGraph_Traits); + void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; auto *amax = reinterpret_cast(buffers[1]); diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 1d1957e0bf..b63a138ad8 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -36,6 +36,37 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o TransposeImpl(input, rows, cols, dtype, stream, output); } 
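For orientation, the TransposeFFI handler added below collapses the dimensions on either side of `transpose_axis` into an (m, n) matrix and writes its (n, m) transpose, which is also what the Python lowering in transpose.py expects. A minimal JAX sketch of the same computation, illustrative only; `reference_transpose` is a hypothetical helper and not part of this patch:

import numpy as np
import jax.numpy as jnp

def reference_transpose(x: jnp.ndarray, transpose_axis: int) -> jnp.ndarray:
    # Collapse the leading and trailing dims around transpose_axis, mirroring
    # product(dims, 0, axis) and product(dims, axis, ndim) in the handler,
    # then transpose the resulting (m, n) matrix.
    if transpose_axis < 0:
        transpose_axis += x.ndim
    m = int(np.prod(x.shape[:transpose_axis]))
    n = int(np.prod(x.shape[transpose_axis:]))
    return x.reshape(m, n).T

x = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
assert reference_transpose(x, transpose_axis=1).shape == (12, 2)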
+Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, + int64_t transpose_axis) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + void *input = input_buf.untyped_data(); + void *output = output_buf->untyped_data(); + + auto input_dims = input_buf.dimensions(); + if (transpose_axis < 0) transpose_axis += input_dims.size(); + auto m = product(input_dims, 0, transpose_axis); + auto n = product(input_dims, transpose_axis, input_dims.size()); + + auto input_shape = std::vector{m, n}; + auto output_shape = std::vector{n, m}; + + auto input_tensor = TensorWrapper(input, input_shape, in_dtype); + auto output_tensor = TensorWrapper(output, output_shape, out_dtype); + + nvte_transpose(input_tensor.data(), output_tensor.data(), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(TransposeHandler, TransposeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("transpose_axis"), + FFI_CudaGraph_Traits); + void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; float *amax = reinterpret_cast(buffers[1]); @@ -82,7 +113,7 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input_cast = input_cast_buf->untyped_data(); auto *input_cast_trans = input_cast_trans_buf->untyped_data(); float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); - assert(amax == amax_out); + NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive."); if (!use_fp8(out_dtype)) { scale = nullptr; @@ -92,10 +123,8 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto input_dims = input_buf.dimensions(); if (transpose_axis < 0) transpose_axis += input_dims.size(); - auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1, - std::multiplies<>()); - auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1, - std::multiplies<>()); + auto m = product(input_dims, 0, transpose_axis); + auto n = product(input_dims, transpose_axis, input_dims.size()); auto input_shape = std::vector{m, n}; auto input_trans_shape = std::vector{n, m}; diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 904d979b8e..583cd0f47a 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -603,14 +603,14 @@ void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_p auto state_index = gen_cuda->GetStateIndex(); auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams ¶ms) { + rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { // ensure the generator use correct state index gen_cuda->SetStateIndex(state_index); auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); params.As>(1) = seed_offset; }; - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback = + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void *functionPtr = reinterpret_cast(&set_rng_state); cudaFunction_t cudaFunc; @@ -1016,14 +1016,14 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p #if PADDLE_VERSION > 261 auto state_index = 
gen_cuda->GetStateIndex(); auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams ¶ms) { + rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { // ensure the generator use correct state index gen_cuda->SetStateIndex(state_index); auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); params.As>(1) = seed_offset; }; - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback = + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void *functionPtr = reinterpret_cast(&set_rng_state); cudaFunction_t cudaFunc; @@ -1383,7 +1383,7 @@ void amax_and_scale_update_inplace_legacy( const int *current_step_id_ptr = reinterpret_cast(GetOptionalDataPtr(current_step_id_tensor)); auto parameterSetter = [current_step_id_ptr, - fwd_update](phi::backends::gpu::CUDAKernelParams ¶ms) { + fwd_update](phi::backends::gpu::gpuKernelParams ¶ms) { if (fwd_update) { int current_step_id = *current_step_id_ptr; params.As(7) = (current_step_id == 0); @@ -1397,7 +1397,7 @@ void amax_and_scale_update_inplace_legacy( float *scale_ptr = scale.data(); float *scale_inv_ptr = scale_inv.data(); - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback = + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void *functionPtr = reinterpret_cast(&UpdateFP8MetaKernel); cudaFunction_t cudaFunc; diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 5f8357a01b..6b153fd3c1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -134,7 +134,7 @@ def _get_supported_versions(version_min, version_max): try: _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) except PackageNotFoundError: - if get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN: + if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN: fa_logger.debug( "flash-attn v2 is not installed. To use, please install it by" """ "pip install flash-attn".""", @@ -158,7 +158,9 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") - elif get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN: + elif ( + torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN + ): fa_logger.warning( "Supported flash-attn versions are %s. Found flash-attn %s.", _get_supported_versions( @@ -183,7 +185,7 @@ def _get_supported_versions(version_min, version_max): try: _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper")) except PackageNotFoundError: - if get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN: + if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN: fa_logger.debug( "flash-attn v3 is not installed. 
To use, please install it by \n%s", _flash_attn_3_installation_steps, @@ -1727,17 +1729,20 @@ def forward( fused_attn_qkv_dtype = None fused_attn_backend = None amax_per_step = None + qkv_dtype = q.dtype + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype + is_input_fp8 = False + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if fp8: if use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_backend = FusedAttnBackend["FP8"] - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA!" + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data @@ -1776,7 +1781,7 @@ def forward( ) if not fp8: q_f16 = q - elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16 = q q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) @@ -1878,11 +1883,7 @@ def forward( batch_p2p_comm, ) - if ( - not fp8 - or fp8_meta["recipe"].fp8_mha - or int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - ): + if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): kv_inputs[i % 2] = p2p_comm_buffers[i] else: # KV exchange is in BF16/FP16, cast received KV in each step @@ -2434,18 +2435,18 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] out_fp8 = None - out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype) - if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): + out_f16 = out.to(qkv_dtype) + if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=META_O, fp8_dtype=fp8_dtype_forward, - dtype=q_fp8.dtype, + dtype=qkv_dtype, ) else: out_ret = out_f16 @@ -2454,7 +2455,7 @@ def forward( q_save, kv_save, out_save = q, kv, out_fp8 fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() - elif fp8 and fp8_meta["recipe"].fp8_mha: + elif fp8 and is_input_fp8: q_fp8 = Float8Tensor( data=q, fp8_meta=fp8_meta, @@ -2511,6 +2512,8 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 return out_ret @staticmethod @@ -2593,7 +2596,7 @@ def backward(ctx, dout): dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) dkv_fp8_ = torch.empty_like(dkv_fp8) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
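The recurring pattern around this hunk is that three flags gate the FP8 attention paths: the input path is inferred from the runtime type of q/k/v, the output path from the recipe's `fp8_mha` field, and the backward precision from the `NVTE_FP8_DPA_BWD` environment variable. A hedged sketch collecting them in one place; `fp8_path_flags` and `float8_cls` are hypothetical names, not part of this file:

import os
from typing import Optional, Tuple

def fp8_path_flags(q, fp8_meta: Optional[dict], float8_cls: type) -> Tuple[bool, bool, bool]:
    """Return (is_input_fp8, is_output_fp8, fp8_dpa_bwd) following the logic used here."""
    is_input_fp8 = isinstance(q, float8_cls)                             # q/k/v arrived as FP8 tensors
    is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha  # recipe requests FP8 outputs
    fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))          # keep the DPA backward in FP8
    return is_input_fp8, is_output_fp8, fp8_dpa_bwd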
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv dout = dout._data @@ -2615,7 +2618,7 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + if ctx.fp8_meta is not None and ctx.is_input_fp8: q, kv = [x.from_float8(x.dtype) for x in [q, kv]] if cp_size_a2a == 1: dout = dout.from_float8(dout_dtype) @@ -2651,7 +2654,7 @@ def backward(ctx, dout): ctx.cp_stream, True, ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: dout = cast_from_fp8( dout, None, @@ -3258,7 +3261,7 @@ def backward(ctx, dout): dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) dkv = dkv_ - if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: + if ctx.fp8 and ctx.is_input_fp8: dq, dkv = [ cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) for x in [dq, dkv] @@ -3281,7 +3284,7 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] - if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: + if ctx.fp8 and ctx.is_input_fp8: dq, dk, dv = [ Float8Tensor( data=x, @@ -3850,19 +3853,22 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" + qkv_dtype = q.dtype fused_attn_backend = None fused_attn_qkv_dtype = None + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype + is_input_fp8 = False + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if fp8: if use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_backend = FusedAttnBackend["FP8"] - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA!" + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." 
+ is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data @@ -3898,7 +3904,7 @@ def forward( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) - if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v q, k, v = [ cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) @@ -3963,14 +3969,14 @@ def forward( out = out.view(-1, batch_size, *out.shape[-2:]) if fp8: - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_fp8 = Float8Tensor( data=out, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=META_O, fp8_dtype=fp8_dtype_forward, - dtype=q_fp8.dtype, + dtype=qkv_dtype, ) out = out_fp8._data out_ret = out_fp8 @@ -3989,7 +3995,7 @@ def forward( if fp8: if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, k_save, v_save, out_save = q, k, v, out - elif fp8_meta["recipe"].fp8_mha: + elif is_input_fp8: q_fp8, k_fp8, v_fp8 = [ Float8Tensor( data=x, @@ -4041,6 +4047,8 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 return out_ret @staticmethod @@ -4062,6 +4070,7 @@ def backward(ctx, dout): fused_attn_backend = None fused_attn_dqkv_dtype = None fused_attn_qkv_dtype = None + dout_dtype = dout.dtype if ctx.fp8: if ctx.use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) @@ -4069,7 +4078,7 @@ def backward(ctx, dout): fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv dout_fp8 = dout @@ -4095,7 +4104,7 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + if ctx.fp8_meta is not None and ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]] if ctx.use_fused_attention: @@ -4192,7 +4201,7 @@ def backward(ctx, dout): dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if ctx.fp8: - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq, dk, dv = [ Float8Tensor( data=x, @@ -4200,7 +4209,7 @@ def backward(ctx, dout): fp8_meta_forward=False, fp8_meta_index=META_DQKV, fp8_dtype=fp8_dtype_backward, - dtype=dout_fp8.dtype, + dtype=dout_dtype, ) for x in [dq, dk, dv] ] @@ -4211,7 +4220,7 @@ def backward(ctx, dout): ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward, - TE_DType[dout_f16.dtype], + TE_DType[dout_dtype], ) for x in [dq, dk, dv] ] @@ -5432,11 +5441,12 @@ def convert_to_torch_float8(tensor, dtype): ) return out - if fp8_meta["recipe"].fp8_mha: - assert all( - isinstance(x, Float8Tensor) - for x in [query_layer, key_layer, value_layer] - ), "q/k/v must be Float8Tensors for FP8 MHA." 
+ # "fp8_mha" decides outputs in fp8, while inputs are inferred from + # the real dtype + assert isinstance(key_layer, query_layer.__class__) and isinstance( + value_layer, query_layer.__class__ + ), "q, k, and v must have the same type." + if isinstance(query_layer, Float8Tensor): fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv else: query_layer, key_layer, value_layer = ( @@ -5578,6 +5588,7 @@ def forward( deterministic, ): # pylint: disable=missing-function-docstring + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -5968,6 +5979,7 @@ def forward( deterministic, ): # pylint: disable=missing-function-docstring + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -6422,6 +6434,7 @@ def forward( deterministic, ): # pylint: disable=missing-function-docstring + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -7671,6 +7684,60 @@ def forward( based on its internal logic. These optimizations trade memory for performance and should be used with care. + .. note:: + .. _cu_seqlens note: + + When training data has variable sequence lengths, users have two options. + + 1. Manipulate the data and pad all sequences to the same length. Use + :attr:`qkv_format` = {"bshd", "sbhd"} and + :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}. + Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask` + (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide + the real sequence length information. For example, a batch of 3 sequences + [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative + sequence length tensors would be + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention. + + 2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and + :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}. + Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`, + as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed + without any padding, and the sequence length tensors would be + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention. + + In certain use cases, a varying number of identifier tokens are inserted between + sequences. These tokens do not participate in the attention calculation. + :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified + in such cases to correctly identify the start and end of each sequence in a batch. + For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and + :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13] + for self-attention. + + .. note:: + .. _max_seqlen note: + + When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch. + :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of + :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will + infer them as such. + + When :attr:`qkv_format` = "thd", sequences have varying lengths. 
:attr:`max_seqlen_q` and + :attr:`max_seqlen_kv` should be the maximum query and key/value sequence lengths in a batch. + When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`. + This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this + overhead, it is recommended that users obtain the maximum sequence lengths from the data loaders + and pass them in. + + - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch, + dynamic shapes need to be supported for tensor construction. FlashAttention and + UnfusedDotProductAttention support them naturally, while FusedAttention requires parameters to be static + in order to create graphs before its performance heuristics analysis. To reduce the number of graphs created + per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size, + :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of + :attr:`query_layer`, "t" dimension of :attr:`key_layer`}. + Parameters ---------- query_layer : torch.Tensor @@ -7693,25 +7760,29 @@ def forward( cu_seqlens_q: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. + See :ref:`note<cu_seqlens note>` for more details. cu_seqlens_kv: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + See :ref:`note<cu_seqlens note>` for more details. cu_seqlens_q_padded: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, `cu_seqlens_q_padded = cu_seqlens_q`. + See :ref:`note<cu_seqlens note>` for more details. cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (with offset) in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, `cu_seqlens_kv_padded = cu_seqlens_kv`. + See :ref:`note<cu_seqlens note>` for more details. max_seqlen_q: Optional[int], default = `None` Maximum sequence length in `query_layer`. - Calculated from `cu_seqlens_q` if not provided. + See :ref:`note<max_seqlen note>` for more details. max_seqlen_kv: Optional[int], default = `None` Maximum sequence length in `key_layer` and `value_layer`. - Calculated from `cu_seqlens_kv` if not provided. + See :ref:`note<max_seqlen note>` for more details. attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = `None`. Type of attention mask passed into @@ -7902,6 +7973,7 @@ def forward( assert ( cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!"
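To make the cu_seqlens/max_seqlen notes above concrete, the sketch below builds the cumulative-sequence-length tensor for the documented example batch [a a a | b b | c c c c] and precomputes the maximum sequence length on the host, which avoids the deduction kernel and CPU-GPU synchronization mentioned in the note (variable names are illustrative only, not part of the patch):

import torch

seq_lens = [3, 2, 4]  # per-sequence lengths, e.g. taken from the data loader
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0)
# cu_seqlens == tensor([0, 3, 5, 9], dtype=torch.int32) -> pass as cu_seqlens_q / cu_seqlens_kv
max_seqlen = max(seq_lens)  # 4 -> pass as max_seqlen_q / max_seqlen_kv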
+ batch_size = len(cu_seqlens_q) - 1 if max_seqlen_q is None: if cu_seqlens_q_padded is not None: seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] @@ -7914,7 +7986,6 @@ def forward( else: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) - batch_size = len(cu_seqlens_q) - 1 cp_size = 1 if isinstance(self.cp_group, dist_group_type): @@ -7929,10 +8000,12 @@ def forward( len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" if qkv_format == "sbhd": - max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0]) + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv batch_size = query_layer.shape[1] else: - max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv batch_size = query_layer.shape[0] max_seqlen_q *= cp_size max_seqlen_kv *= cp_size @@ -7941,13 +8014,13 @@ def forward( assert all( seqlens_q <= max_seqlen_q ), """Sequence lengths indicated by cu_seqlens_q must be no greater than - the sequence dimention in 'query_layer'!""" + the sequence dimension in 'query_layer'!""" if cu_seqlens_kv is not None: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] assert all( seqlens_kv <= max_seqlen_kv ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than - the sequence dimention in 'key_layer' and 'value_layer'!""" + the sequence dimension in 'key_layer' and 'value_layer'!""" if cu_seqlens_q is None or cu_seqlens_kv is None: if "padding" in attn_mask_type: assert ( @@ -8433,6 +8506,8 @@ def __init__( self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_attention_heads = num_attention_heads self.return_bias = return_bias + self.cp_size = 1 + self.cp_rank = 0 kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) @@ -8651,6 +8726,21 @@ def set_context_parallel_group( across each CP sub-group (e.g., via NVLink), then exchanging KV with p2p between sub-groups (e.g., via IBLink). """ + if isinstance(cp_group, dist_group_type): + self.cp_size = get_distributed_world_size(cp_group) + self.cp_rank = get_distributed_rank(cp_group) + elif isinstance(cp_group, list): + assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" + assert ( + cp_comm_type == "a2a+p2p" + ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" + cp_size_a2a = get_distributed_world_size(cp_group[0]) + cp_rank_a2a = get_distributed_rank(cp_group[0]) + cp_size_p2p = get_distributed_world_size(cp_group[1]) + cp_rank_p2p = get_distributed_rank(cp_group[1]) + self.cp_size = cp_size_a2a * cp_size_p2p + self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a + # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): if index == 0: @@ -8985,8 +9075,24 @@ def forward( q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] 
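The set_context_parallel_group change above flattens a hierarchical ("a2a+p2p") context-parallel group into a single (cp_size, cp_rank) pair, which the next hunk then forwards to apply_rotary_pos_emb. A small numeric sketch of that flattening, with made-up group sizes:

# Hypothetical two-level CP layout: an inner all-to-all group (e.g. within an NVLink domain)
# nested inside an outer point-to-point ring (e.g. across IBLink).
cp_size_a2a, cp_rank_a2a = 2, 1
cp_size_p2p, cp_rank_p2p = 4, 3

cp_size = cp_size_a2a * cp_size_p2p                  # 8 ranks in total
cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a    # this rank is 7 in the flattened ordering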
- query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) + query_layer = apply_rotary_pos_emb( + query_layer, + q_pos_emb, + self.qkv_format, + fused=True, + cu_seqlens=cu_seqlens_q, + cp_size=self.cp_size, + cp_rank=self.cp_rank, + ) + key_layer = apply_rotary_pos_emb( + key_layer, + k_pos_emb, + self.qkv_format, + fused=True, + cu_seqlens=cu_seqlens_kv, + cp_size=self.cp_size, + cp_rank=self.cp_rank, + ) # =========================== # Core attention computation diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index fd1eb4a810..932bb3cafa 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -45,8 +45,8 @@ def fp8_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, - ub_algo: tex.UbufOverlapAlgo = None, - ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None, + ub_algo: tex.CommOverlapAlgo = None, + ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, extra_output_tensor: torch.Tensor = None, ) -> torch.Tensor: """TN layout GEMM with fp8 inputs.""" @@ -107,7 +107,7 @@ def fp8_gemm( fn = torch.ops.tex_ts.te_gemm_ts if ub_algo is not None: assert ub is not None, "ub object is None!" - if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor @@ -115,11 +115,11 @@ def fp8_gemm( args = tuple( args + ( - 1, + tex.CommOverlapType.AG, extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor @@ -127,23 +127,23 @@ def fp8_gemm( args = tuple( args + ( - 0, + tex.CommOverlapType.RS, extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P: fn = ub.atomic_gemm_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: fn = ub.split_overlap_rs assert ( extra_output_tensor is not None @@ -155,13 +155,13 @@ def fp8_gemm( extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: fn = ub.split_overlap_rs_p2p assert ( extra_output_tensor is not None ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS: fn = ub.atomic_gemm_overlap_rs assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" args = tuple( @@ -171,16 +171,13 @@ def fp8_gemm( extra_output_tensor, ) ) - elif ub_algo == 
tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P: fn = ub.atomic_gemm_overlap_rs_p2p assert ( extra_output_tensor is not None ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) - if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: - out = fn(*args) - else: - _ = fn(*args) + _ = fn(*args) return out, gelu_input @@ -198,8 +195,8 @@ def gemm( out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, use_bias: bool = False, - ub_algo: tex.UbufOverlapAlgo = None, - ub: tex.UbufCommOverlap = None, + ub_algo: tex.CommOverlapAlgo = None, + ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, extra_output_tensor: torch.Tensor = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Non FP8 GEMM.""" @@ -270,19 +267,19 @@ def gemm( fn = torch.ops.tex_ts.te_gemm_ts if ub_algo is not None: assert ub is not None, "ub object is None!" - if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap - args = tuple(args + (1, empty_tensor)) - elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + args = tuple(args + (tex.CommOverlapType.AG, empty_tensor)) + elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap - args = tuple(args + (0, empty_tensor)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + args = tuple(args + (tex.CommOverlapType.RS, empty_tensor)) + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: fn = ub.split_overlap_rs assert ( extra_output_tensor is not None @@ -294,7 +291,7 @@ def gemm( extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: fn = ub.split_overlap_rs_p2p assert ( extra_output_tensor is not None diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h deleted file mode 100644 index 3b4e126943..0000000000 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ /dev/null @@ -1,1303 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ -#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "common/common.h" -#include "common/util/cuda_driver.h" -#include "common/util/logging.h" -#include "common/util/system.h" -#include "extensions.h" -#include "userbuffers/userbuffers.h" - -#define HALF_BYTES 2 -#define UB_MAX_SM 32 - -using namespace torch::indexing; -using namespace std::placeholders; - -namespace ubuf { - -bool device_supports_multicast() { - int dev, supports_multicast; - CUdevice cudev; - - NVTE_CHECK_CUDA(cudaGetDevice(&dev)); - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, dev); - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &supports_multicast, - CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); - - return static_cast(supports_multicast); -} - -bool ubuf_built_with_mpi() { -#ifdef NVTE_UB_WITH_MPI - return true; -#else - return false; -#endif -} - -class UbufBootstrapCallbacks : torch::CustomClassHolder { - private: - bool initialized{false}; - bool backend_is_nccl{false}; - std::map pgs; - - public: - UbufBootstrapCallbacks() { -#ifndef NVTE_UB_WITH_MPI - NVTE_ERROR("Internal TE error: Dummy UbufBootstrapCallbacks init without NVTE_UB_WITH_MPI=1!"); -#endif - } // empty constructor for NVTE_UB_WITH_MPI=1 - - UbufBootstrapCallbacks(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group) { - pgs.insert({"world", world_group}); - c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); - backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); - - NVTE_CHECK(intra_node_group->getBackendType() == backend, - "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); - pgs.insert({"intra", intra_node_group}); - - initialized = true; - } - - ~UbufBootstrapCallbacks() { - for (auto &pg : pgs) pg.second = nullptr; - backend_is_nccl = false; - initialized = false; - } - - void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, - char *group) { - NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ", - "with valid process groups!"); - - auto localtensor = - torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; - auto globaltensor = - torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto globaltmp = (backend_is_nccl) ? 
globaltensor.cuda() : globaltensor; - - std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; - std::vector localchunk = {localtmp}; - auto work = pgs[group]->allgather(globalchunks, localchunk); - work->wait(); - - if (backend_is_nccl) { - globaltensor.copy_(globaltmp.cpu()); - globaltmp = torch::Tensor(); - localtmp = torch::Tensor(); - } - } - - void ub_barrier(char *group) { - NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ", - "with valid process groups!"); - auto work = pgs[group]->barrier(); - work->wait(); - } -}; - -enum class COMM_TYPE { RS = 0, AG = 1 }; - -enum class UBOverlapAlgo { - BULK_OVERLAP_AG = 0, - BULK_OVERLAP_RS = 1, - SPLIT_PIPELINED_AG_P2P = 2, - SPLIT_PIPELINED_RS = 3, - SPLIT_PIPELINED_RS_P2P = 4, - ATOMIC_GEMM_RS = 5, - ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 -}; - -struct UbufBase { - static inline communicator *_ub_comm{nullptr}; - static inline bool comm_created{false}; -}; -struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { - int _tp_id; - int _tp_size; - int _num_splits; - int _math_sms; - int _ub_reg; - void *_ubuf_ptr; - torch::Tensor _ubuf; - torch::Tensor output_tensor; - torch::Tensor _ubuf_scale_inv; - bool _ubuf_scale_inv_initialized; - torch::Tensor counter; - at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); - std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm; - int _num_comm_sm; - int _cga_size; - int _use_ce; - bool _atomic_gemm; - - UbufCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size, - int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm, - UbufBootstrapCallbacks &callbacks) { - // Initialize userbuf communicator - if (!comm_created) { - if (myrank == 0) { - printf("!!! [UB] Create Userbuffers Communicator\n"); - } -#ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); -#else - create_communicator_grouped2( - &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5), - std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1); -#endif - comm_created = true; - } - _use_ce = 0; - _num_comm_sm = num_comm_sm; - _cga_size = comm_cga_size; - - // Allocate and register extra userbuffers - int ubuf_bytes = sample.numel() * sample.element_size(); - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); - - if (_ub_comm->myrank == 0) { - printf("!!! [UB] Register UBuf %d\n", _ub_reg); - } - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { - cudaStream_t stream; - cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); - _stream_compute.push_back( - at::cuda::getStreamFromExternal(stream, stream_main.device_index())); - } - - _num_splits = num_splits; - _tp_size = tp_size; - _tp_id = (_ub_comm->myrank % _tp_size); - _ubuf_scale_inv_initialized = false; - - // Set the number of SMs for GEMM with margin - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - _math_sms = (set_sm_margin) ? 
prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; - _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); - - output_tensor = torch::Tensor(); - _atomic_gemm = atomic_gemm; - if (_atomic_gemm) { - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({num_splits * 2}, counter_options); - counter.index_put_({Slice(None, num_splits)}, 1); - } - // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_d2dcopy, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_comm, 0); - } - - ~UbufCommOverlap() { - cudaEventDestroy(_stop_comm); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_start_d2dcopy); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); - - if (comm_created) { -#ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); -#else - destroy_communicator(_ub_comm); -#endif - comm_created = false; - } - } - - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ - std::vector bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get the current userbuf offset - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type == COMM_TYPE::RS) { - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication: AG and RS - if (_comm_type == COMM_TYPE::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm); - } else if (_comm_type == COMM_TYPE::RS) { - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - comm_elements *= 2; - float *scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - assert(rs_output.numel() == _ubuf.numel() / _tp_size); - assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); - assert(rs_output.element_size() == 2); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, scale_inv_ptr, _ub_reg, 0, - comm_elements, _ub_comm, - (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, - (cudaStream_t)_stream_comm); - } - } else { - NVTE_ERROR("Not supported communication type."); - } - - if (A_scale_inverse.numel()) 
A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale, - D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, workspaceSize, - accumulate, use_split_accumulator, _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - - // Generate output tensor from userbuf data pointer - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - _ub_comm->sms = ori_sms; - - return {D, output_tensor}; - } // bulk_overlap - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions - int m = A.size(0); - int k = A.size(1); - int n = B.size(0); - int m_chunk = m / _num_splits; - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - - torch::Tensor input_a = torch::from_blob(input_a_chunk_ptr, {m, k}, A.options()); - torch::Tensor output_d = torch::from_blob(output_buf_chunk_ptr, {n, m}, _ubuf.options()); - // torch::zeros({n, m}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_atomic_gemm(input_a, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_d, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, - counter); - - for (int i = 0; i < _num_splits; 
i++) { - const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, - _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _num_splits, &counter_ptr[i], _ub_comm, - (cudaStream_t)_stream_comm); - } - } else if (env_p != nullptr && env_p[0] == '2') { - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_strided_multiatomic_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, - counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, - m, _num_splits, counter_ptr, _ub_comm, - (cudaStream_t)_stream_comm); - } - break; - } else { - assert(_ubuf.element_size() != 1); - consumer(counter_ptr, i, (cudaStream_t)_stream_comm); - reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - - _ub_comm->sms = ori_sms; - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - - return; - } // split_overlap_rs - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, at::Tensor rs_output) { - // Get GEMM dimensions - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - int m = A.size(0); - int k = A.size(1); - int n = B.size(0); - int m_chunk = m / _num_splits; - int input_a_chunk_size = m_chunk * k; - int output_chunk_size = n * m_chunk; - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = 
at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - - if (gemm_overlap) { - torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - - for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, - m, _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - int last_compute_stream_id = - (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - NVTE_CHECK_CUDA( - cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Last communication chunk with max SM - _ub_comm->sms = UB_MAX_SM; - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - 
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, - (_num_splits - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } - } else { - for (int i = 0; i < _num_splits; i++) { - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, - (cudaStream_t)_stream_compute[i % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk. Uses MAX_SM at the last chunk - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - } - } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - _ub_comm->sms = ori_sms; - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - - return; - } // split_overlap_rs - - void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { - _ubuf_scale_inv = scale_inv; - _ubuf_scale_inv_initialized = true; - } - - bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } - /* - ** Helper function to copy input to _ubuf - */ - void copy_input_to_ubuf(torch::Tensor input, int comm_type) { - char *ubuf_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type == COMM_TYPE::AG) { - if ((input.numel() * _tp_size) != _ubuf.numel() || - input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } else { - if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - } - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - 
NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)_stream_comm)); - } - - torch::Tensor &get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); - if (_comm_type == COMM_TYPE::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - return output_tensor; - } - - bool is_atomic_gemm() { return _atomic_gemm; } - bool is_p2p_overlap() { return false; } -}; // UbufCommOverlap - -struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { - int _tp_id; - int _tp_size; - int _ub_reg, _ub_reg2; - int _next_rank, _prev_rank, _rank, _rank_round_tp; - int _aggregate2; - int _math_sms; - int _self_chunk_id; - void *_ubuf_ptr; - torch::Tensor _ubuf; - torch::Tensor counter; - torch::Tensor _ubuf_scale_inv; - bool _ubuf_scale_inv_initialized; - std::vector _ubufs; - at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true); - at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); - std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv; - int _use_ce; - int _num_comm_sm; - int _cga_size; - bool _atomic_gemm; - - UbufP2PCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size, - bool set_sm_margin, bool aggregate2, int num_max_streams, - bool is_reduce_scatter, bool atomic_gemm, bool use_ce, - UbufBootstrapCallbacks &callbacks) { - // Initialize userbuf communicator - if (!comm_created) { - if (myrank == 0) { - printf("!!! [UB] Create Userbuffers Communicator\n"); - } -#ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); -#else - create_communicator_grouped2( - &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5), - std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1); -#endif - comm_created = true; - } - _use_ce = use_ce; - _num_comm_sm = num_comm_sm; - _cga_size = comm_cga_size; - - // Create workspace tensor with userbuffer - int ubuf_bytes = sample.numel() * sample.element_size(); - int ubuf_chunk_bytes = ubuf_bytes / tp_size; - int num_ubuf_chunks = tp_size; - if (is_reduce_scatter) { - // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk - // outputs for reduction at the end of the pipelining. - ubuf_bytes = static_cast(ubuf_bytes / tp_size * (tp_size * 2 - 1)); - num_ubuf_chunks = static_cast(tp_size * 2 - 1); - } - - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = torch::from_blob( - _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); - if (_ub_comm->myrank == 0) { - printf("!!! 
[UBP2P] Register UBuf %d\n", _ub_reg); - } - - // Create tensor chunks for easy management - char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); - for (int i = 0; i < num_ubuf_chunks; i++) { - auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, - sample.options()); - _ubufs.push_back(std::move(ubuf_chunk)); - ubuf_byte_ptr += ubuf_chunk_bytes; - } - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - for (int i = 0; i < std::min(num_max_streams, tp_size); i++) { - cudaStream_t stream; - cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); - _stream_compute.push_back( - at::cuda::getStreamFromExternal(stream, stream_main.device_index())); - } - - // Set the number of SMs for GEMM with margin - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; - _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); - - _tp_size = tp_size; - _aggregate2 = aggregate2; - - _rank = _ub_comm->myrank; - _tp_id = (_rank % _tp_size); - _rank_round_tp = (_rank / _tp_size) * _tp_size; - _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; - _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; - _ubuf_scale_inv_initialized = false; - - _atomic_gemm = atomic_gemm; - _self_chunk_id = _tp_id; - if (_atomic_gemm) { - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({_tp_size * 2}, counter_options); - counter.index_put_({Slice(None, _tp_size)}, 1); - - if (!is_reduce_scatter) { - const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); - if (_rank == 0 && env_p != nullptr) { - if (env_p[0] == '1') { - _use_ce = 0; - _ub_comm->push = 1; - printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); - } - } - _self_chunk_id = 0; - counter.index_put_({_self_chunk_id}, 0); - } - } - - // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_send, 0); - cudaEventCreateWithFlags(&_stop_recv, 0); - } - - ~UbufP2PCommOverlap() { - cudaEventDestroy(_stop_recv); - cudaEventDestroy(_stop_send); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); - - if (comm_created) { -#ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); -#else - destroy_communicator(_ub_comm); -#endif - comm_created = false; - } - } - - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. 
- */ - torch::Tensor atomic_gemm_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions between TN and NN input layouts - const int m = (transa) ? A.size(0) : A.size(1); - const int n = _ubuf.size(0); - const int n_chunk = n / _tp_size; - - // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - - // Create an GEMM output buffer with N+1 chunks in a contiguous memory - torch::Tensor D_buffer = torch::empty({n_chunk * (_tp_size + 1), m}, D.options()); - D = torch::from_blob(D_buffer.data_ptr(), {D.size(0), D.size(1)}, D.options()); - - // Get output and workspace data pointers - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - - for (int i = 0; i < _tp_size - 1; i++) { - // Set the userbuffer id. Buffer under send is the input for the current - // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to - // have the AG output in all ranks to be contiguous after the ring - // exchanges - int send_chunk_id = i; - int recv_chunk_id = i + 1; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - - const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - if (i == 0) { - _ub_comm->use_ce = 0; - userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, - _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, - true, (cudaStream_t)_stream_recv); - } - } else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_recv); - producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv); - } - if (i == 0) { - te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb, - D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms, 0, _tp_size, false, counter); - } - } - - // Store the input activation for backprop - if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); - assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); - NVTE_CHECK_CUDA( - cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(), - _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - } - - // Reset atomic counters - consumer_batch(counter_ptr, 1, _tp_size, (cudaStream_t)stream_main); - - // Copy the first GEMM output chunk to the end chunk position of D_buffer - char *src_ptr = reinterpret_cast(D_buffer.data_ptr()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, - n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); - // Return the last N rows of D_buffer - _ub_comm->sms = ori_sms; - torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); - return D_return; - } // atomic_gemm_overlap_ag - - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. - */ - torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions between TN and NN input layouts - const int m = (transa) ? A.size(0) : A.size(1); - const int k = (transa) ? 
A.size(1) : A.size(0); - const int n_chunk = _ubufs[0].size(0); - - // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const bool do_gelu = pre_gelu_out.numel() > 0; - const int output_chunk_bytes = (n_chunk * m) * D.element_size(); - const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; - - // Get output and workspace data pointers - char *output_ptr = reinterpret_cast(D.data_ptr()); - char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } - if (_aggregate2) { - const int num_steps = _tp_size / 2; - char *input_b_ptr = reinterpret_cast(_ubuf.data_ptr()); - - // Initial 1X input chunk exchange between neighboring peers - int send_chunk_id = _tp_id; - int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); - - int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; - const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; - const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; - - // Ring exchange of 2X inputs chunks - for (int i = 0; i < num_steps; i++) { - send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; - recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; - send_offset = comm_bytes * send_chunk_id; - recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - torch::Tensor input_b_chunk = - torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options()); - torch::Tensor output_chunk = torch::from_blob( - output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options()); - if (do_gelu) { - pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk * 2, m}, pre_gelu_out.options()); - } - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - if (i < num_steps - 1) { - // P2P communication - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, - prev_rank, (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - } - } - } else { - for (int i = 0; i < _tp_size; i++) { - // Set the userbuffer id. Buffer under send is the input for the current - // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to - // have the AG output in all ranks to be contiguous after the ring - // exchanges - int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; - int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - torch::Tensor output_chunk = torch::from_blob( - output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options()); - if (do_gelu) { - pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk, m}, pre_gelu_out.options()); - } - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, _ubufs[send_chunk_id], B_scale_inverse, B_type, - transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - if (i < _tp_size - 1) { - // P2P communication - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - } - } - } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - _ub_comm->sms = ori_sms; - - return D; - } // split_overlap_ag - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - - // Get communication and GEMM input chunk sizes - const int 
comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - - // Get input and workspace data pointers - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - - // Atomic GEMM - // Process GEMM chunks in the order that AG+GEMM places the output chunks. - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, _ubuf, - D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk, - workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, 0, _tp_size, - true, counter); - - // P2P communication chunk - for (int i = 1; i < _tp_size; i++) { - int send_chunk_id = i - 1; - int recv_chunk_id = send_chunk_id + _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - - consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - (cudaStream_t)_stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, - (cudaStream_t)_stream_recv); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main);); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - } - _ub_comm->sms = ori_sms; - } - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = 
_num_comm_sm; - _ub_comm->cga_size = _cga_size; - int k = A.size(1); - int n = B.size(0); - - // Get communication and GEMM input chunk sizes - int n_chunk = n / _tp_size; - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const int input_b_chunk_bytes = n_chunk * k * B.element_size(); - - // Get input and workspace data pointers - char *input_b_ptr = reinterpret_cast(B.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } - - // GEMM and send/recv chunks - for (int i = 0; i < _tp_size; i++) { - // GEMM chunk - int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options()); - // Store the last GEMM chunk output to the recieve buffer. - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, - _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - - if (i > 0) { - // P2P communication chunk - int send_offset = comm_bytes * (i - 1); - int recv_offset = comm_bytes * (i - 1 + _tp_size); - int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - send_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - recv_rank, (cudaStream_t)_stream_recv); - } - } - at::cuda::setCurrentCUDAStream(stream_main); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char 
*reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main);); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - } - _ub_comm->sms = ori_sms; - } - - /* - ** Copy input to _ubufs[0] - */ - void copy_input_to_ubuf(torch::Tensor input, bool chunk) { - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - if (chunk) { - // Copy input to the target ubuf chunk by rank offset - if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); - } else { - if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); - } - } - - torch::Tensor get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); - if (_comm_type == COMM_TYPE::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - } - - void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { - _ubuf_scale_inv = scale_inv; - _ubuf_scale_inv_initialized = true; - } - - bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } - bool is_atomic_gemm() { return _atomic_gemm; } - bool is_p2p_overlap() { return true; } -}; // UbufP2PCommOverlap - -} // namespace ubuf - -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 04a1193a71..175a7b0e90 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -37,12 +38,14 @@ #include #include +#include #include #include #include #include #include #include +#include #include #include "common/util/logging.h" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c30e583178..b039bf2d1b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ +#include + #include "common.h" #include "common/common.h" @@ -504,4 +506,184 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list); +/*************************************************************************************************** + * Comm+GEMM Overlap Wrappers + **************************************************************************************************/ + +class CommOverlapHelper : torch::CustomClassHolder { + private: + bool initialized{false}; + bool backend_is_nccl{false}; + std::map pgs; + + public: + int myrank = -1; + int numranks = -1; + int mylocal = -1; + int numlocal = -1; + int mynode = -1; + int numnodes = -1; + + CommOverlapHelper(); + + CommOverlapHelper(c10d::ProcessGroup *world_group, + std::optional intra_node_group, + std::optional inter_node_group); + + ~CommOverlapHelper(); + + void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, + ExtComm comm); + + void ub_barrier(ExtComm comm); +}; + +class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { + private: + torch::Tensor _ubuf_torch; + torch::Tensor _ubuf_counter; + + public: + CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + + void set_ubuf_scale_inv(torch::Tensor scale_inv) { + assert(scale_inv.numel()); + assert(scale_inv.scalar_type() == torch::kFloat32); + transformer_engine::CommOverlapBase::set_ubuf_scale_inv( + reinterpret_cast(scale_inv.data_ptr())); + } + + void copy_input_to_ubuf(torch::Tensor input, int comm_type); + + torch::Tensor get_ubuf_output(int comm_type); + + /* + ** Bulk GEMM + COMM + ** This function assumes the communication input is pre-copied to _ubuf + */ + std::vector bulk_overlap( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, 
at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + transformer_engine::CommOverlapType comm_type, at::Tensor rs_output); + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output); + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, at::Tensor rs_output); +}; // CommOverlap + +class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { + private: + torch::Tensor _ubuf_torch; + torch::Tensor _ubuf_counter; + + public: + CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + transformer_engine::CommOverlapType comm_type, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, + bool use_ce = true, bool aggregate = false); + + void set_ubuf_scale_inv(torch::Tensor scale_inv) { + assert(scale_inv.numel()); + assert(scale_inv.scalar_type() == torch::kFloat32); + transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv( + reinterpret_cast(scale_inv.data_ptr())); + } + + void copy_input_to_ubuf(torch::Tensor input, bool chunk); + + torch::Tensor get_ubuf_output(int comm_type); + + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is + *needed to have AG outputs + ** in each rank to be in the contiguous memory space after all ring exchange + *phases. 
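+ ** B_copy, when non-empty, is expected to receive the local rank's chunk of the gathered input
+ ** at the end of the overlap (mirroring the previous UbufP2PCommOverlap behavior shown elsewhere
+ ** in this diff), so the caller can keep the input once the communication buffer is reused.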
+ */ + void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, at::Tensor B_copy); + + /* + ** Split AllGather + GEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is + *needed to have AG outputs + ** in each rank to be in the contiguous memory space after all ring exchange + *phases. + */ + void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor B_copy); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, at::Tensor rs_output); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output); +}; // CommOverlapP2P + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp new file mode 100644 index 0000000000..d212d13516 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -0,0 +1,480 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../extensions.h" + +#define HALF_BYTES 2 +#define UB_MAX_SM 32 + +using namespace torch::indexing; +using namespace std::placeholders; + +namespace te = transformer_engine; + +#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, B, B_scale_inv, \ + B_fp8_index, B_type, D, D_amax, D_scale, D_type, bias, \ + bias_type, pre_gelu_out, workspace) \ + A = A.contiguous(); \ + void *A_scale_inv_ptr = nullptr; \ + if (te::is_fp8_dtype(A_type)) { \ + assert(A_scale_inv.numel()); \ + A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ + } \ + auto A_ = makeTransformerEngineTensor( \ + A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, \ + nullptr, nullptr, A_scale_inv_ptr); \ + B = B.contiguous(); \ + void *B_scale_inv_ptr = nullptr; \ + if (te::is_fp8_dtype(B_type)) { \ + assert(B_scale_inv.numel()); \ + B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ + } \ + auto B_ = makeTransformerEngineTensor( \ + B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, \ + nullptr, nullptr, B_scale_inv_ptr); \ + void *D_amax_ptr = nullptr; \ + void *D_scale_ptr = nullptr; \ + if (te::is_fp8_dtype(D_type)) { \ + assert(D_amax.numel()); \ + D_amax_ptr = D_amax.data_ptr(); \ + assert(D_scale.numel()); \ + D_scale_ptr = D_scale.data_ptr(); \ + } \ + auto D_ = makeTransformerEngineTensor( \ + D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, \ + D_amax_ptr, D_scale_ptr, nullptr); \ + auto bias_ = makeTransformerEngineTensor( \ + bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); \ + const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ + ? std::vector{static_cast(pre_gelu_out.size(0))} \ + : std::vector{static_cast(pre_gelu_out.size(0)), \ + static_cast(pre_gelu_out.size(1))}; \ + auto pre_gelu_out_ = makeTransformerEngineTensor( \ + pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ + auto workspace_ = makeTransformerEngineTensor( \ + workspace.data_ptr(), std::vector{static_cast(workspace.size(0))}, \ + te::DType::kByte); + +/*************************************************************************************************** + * CommOverlapHelper + **************************************************************************************************/ + +CommOverlapHelper::CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI + NVTE_ERROR("Internal TE error: Dummy CommOverlapHelper init without NVTE_UB_WITH_MPI=1!"); +#endif +} // empty constructor for NVTE_UB_WITH_MPI=1 + +CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, + std::optional intra_domain_group, + std::optional inter_domain_group) { +#ifndef NVTE_UB_WITH_MPI + pgs.insert({"world", world_group}); + myrank = pgs["world"]->getRank(); + numranks = pgs["world"]->getSize(); + c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType(); + backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); + + if (intra_domain_group.has_value()) { + // Get local rank on node and number of local ranks + NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, + "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", + "group!", pgs["world"]->getBackendName()); + pgs.insert({"intra", intra_domain_group.value()}); + mylocal = pgs["intra"]->getRank(); + numlocal = pgs["intra"]->getSize(); + + if (numlocal == numranks) { + // Intra-node group is same 
as the world group so there can only be 1 node + NVTE_CHECK( + mylocal == myrank, + "Internal TE error: Local rank must be equal to global rank when intra-node group size ", + "is equal to the world group size!"); + mynode = 0; + numnodes = 1; + } else { + // Intra-node group is different than the world group so there must be multiple nodes + NVTE_CHECK( + inter_domain_group.has_value(), + "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ", + "identical to the world_group!"); + + // Get node ID and number of nodes + NVTE_CHECK( + inter_domain_group.value()->getBackendType() == backend, + "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", + "group!", pgs["world"]->getBackendName()); + pgs.insert({"inter", inter_domain_group.value()}); + mynode = pgs["inter"]->getRank(); + numnodes = pgs["inter"]->getSize(); + } + } else { + // Intra-node group is not set so we assume there is only 1 node + mylocal = myrank; + numlocal = numranks; + pgs.insert({"intra", world_group}); + + mynode = 0; + numnodes = 1; + } + + initialized = true; +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!"); +#endif +} + +CommOverlapHelper::~CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI + for (auto &pg : pgs) pg.second = nullptr; + backend_is_nccl = false; + initialized = false; +#endif +} + +void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, + size_t localbytes, ExtComm group) { +#ifndef NVTE_UB_WITH_MPI + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + + auto localtensor = + torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; + auto globaltensor = + torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto globaltmp = (backend_is_nccl) ? 
globaltensor.cuda() : globaltensor; + + std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; + std::vector localchunk = {localtmp}; + auto work = pgs[group]->allgather(globalchunks, localchunk); + work->wait(); + + if (backend_is_nccl) { + globaltensor.copy_(globaltmp.cpu()); + globaltmp = torch::Tensor(); + localtmp = torch::Tensor(); + } +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_allgather is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif +} + +void CommOverlapHelper::ub_barrier(ExtComm group) { +#ifndef NVTE_UB_WITH_MPI + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + auto work = pgs[group]->barrier(); + work->wait(); +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif +} + +/*************************************************************************************************** + * CommOverlap + **************************************************************************************************/ + +CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, + bool set_sm_margin, bool atomic_gemm) + : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, + helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, + num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { + // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to + // for PyTorch to factor externally allocated memory into its memory pool and garbage collection + // threshold calculation. 
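+  // NOTE: torch::from_blob() does not take ownership of the Userbuffers allocation (no deleter is
+  // attached), so the communicator object stays responsible for the buffer's lifetime and these
+  // wrappers must not outlive it.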
+ _ubuf_torch = torch::from_blob( + _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, + at::device(torch::kCUDA).dtype(buffer_dtype)); + if (_atomic_gemm) { + _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, + at::device(torch::kCUDA).dtype(torch::kInt32)); + } +} + +/* +** Bulk GEMM + COMM +** This function assumes the communication input is pre-copied to _ubuf +*/ +std::vector CommOverlap::bulk_overlap( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + te::CommOverlapType comm_type, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::bulk_overlap(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, + grad, accumulate, use_split_accumulator, comm_type, rs_out_, + stream_main); + + // Get the current userbuf offset + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == te::CommOverlapType::RS) { + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } + + // Generate output tensor from userbuf data pointer + int output_c_dim0 = + (comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + auto output_tensor = + torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); + + return {D, output_tensor}; +} // CommOverlap::bulk_overlap + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlap::atomic_gemm_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + gemm_overlap, rs_out_, stream_main); +} // CommOverlap::split_overlap_rs + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + te::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + te::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, + te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + gemm_overlap, rs_out_, stream_main); +} // CommOverlap::split_overlap_rs + +/* +** Helper function to copy input to _ubuf +*/ +void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) { + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type == te::CommOverlapType::AG) { + if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } else { + if (input.numel() != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + } + + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), + 
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); +} + +torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) + NVTE_ERROR("Invalid comm_type"); + if (_comm_type == te::CommOverlapType::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + int output_c_dim0 = + (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, + torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); +} + +/*************************************************************************************************** + * CommOverlapP2P + **************************************************************************************************/ + +CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + te::CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool atomic_gemm, bool use_ce, bool aggregate) + : te::CommOverlapP2PBase( + buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, + helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) { + // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to + // for PyTorch to factor externally allocated memory into its memory pool and garbage collection + // threshold calculation. + _ubuf_torch = torch::from_blob( + _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, + at::device(torch::kCUDA).dtype(buffer_dtype)); + if (_atomic_gemm) { + _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, + at::device(torch::kCUDA).dtype(torch::kInt32)); + } +} + +/* +** Split AllGather + AtomicGEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is +*needed to have AG outputs +** in each rank to be in the contiguous memory space after all ring exchange +*phases. 
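+** In practice the caller is expected to stage its local shard with copy_input_to_ubuf(input_b,
+** chunk=true), which places it in _ubufs[rank_id], before invoking this overlap method.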
+*/ +void CommOverlapP2P::atomic_gemm_overlap_ag( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto B_copy_ = makeTransformerEngineTensor(B_copy); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::atomic_gemm_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, + use_split_accumulator, B_copy_, stream_main); +} // atomic_gemm_overlap_ag + +/* +** Split AllGather + GEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is +*needed to have AG outputs +** in each rank to be in the contiguous memory space after all ring exchange +*phases. +*/ +void CommOverlapP2P::split_overlap_ag( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto B_copy_ = makeTransformerEngineTensor(B_copy); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::split_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + B_copy_, stream_main); +} // split_overlap_ag + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2P::atomic_gemm_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, + use_split_accumulator, rs_out_, stream_main); +} + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2P::split_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t 
A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + rs_out_, stream_main); +} + +/* +** Copy input to _ubufs[0] +*/ +void CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) { + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + if (chunk) { + // Copy input to the target ubuf chunk by rank offset + if (input.numel() != (int64_t)_ubufs[0].numel() || + input.element_size() != (int64_t)_ubufs[0].element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); + } else { + if (input.numel() != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); + } +} + +torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) + NVTE_ERROR("Invalid comm_type"); + if (_comm_type == te::CommOverlapType::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); + int output_c_dim0 = + (_comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu index 09b53a8976..7d49a0848b 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu @@ -179,7 +179,7 @@ struct AdamFunctorMaster { } }; -template +template struct AdamFunctor { __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, // NOLINT(*) @@ -199,10 +199,10 @@ struct AdamFunctor { index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; index_t n = tl.sizes[tensor_loc]; - T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); g += chunk_idx * chunk_size; - T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); @@ -223,10 +223,10 @@ struct AdamFunctor { for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { - r_g[ii] = g[i]; - r_p[ii] = p[i]; - r_m[ii] = m[i]; - r_v[ii] = v[i]; + r_g[ii] = static_cast(g[i]); + r_p[ii] = static_cast(p[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); } else { r_g[ii] = MATH_T(0); r_p[ii] = MATH_T(0); @@ -259,9 +259,9 @@ struct AdamFunctor { for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { - p[i] = r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; + p[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } @@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, } } + const auto g_in_type = tensor_lists[0][0].scalar_type(); const auto p_in_type = tensor_lists[1][0].scalar_type(); auto tl_size = tensor_lists.size(); @@ -503,13 +504,15 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, // Assume single type across p,g,m1,m2 now DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "adam", - multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctor(), beta1, beta2, - bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);) + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, + tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); } else { // g, p, m, v, p_master - const auto g_in_type = tensor_lists[0][0].scalar_type(); DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "adam", DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( @@ -525,12 +528,13 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, // Assume single type across p,g,m1,m2 now DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "adam", - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctor(), beta1, beta2, - bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);) + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + 
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); } else { - const auto g_in_type = tensor_lists[0][0].scalar_type(); DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "adam", DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7bd5a2d8c8..39679ed669 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -4,12 +4,15 @@ * See LICENSE for license information. ************************************************************************/ -#include +#include +#include -#include "../comm_gemm_overlap.h" #include "../extensions.h" +#include "common/util/pybind_helper.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + // Permutation functions m.def("moe_permute_fwd", moe_permute_fwd); m.def("moe_permute_bwd", moe_permute_bwd); @@ -226,90 +229,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); - m.def("device_supports_multicast", &ubuf::device_supports_multicast, - py::call_guard()); - - m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi, - py::call_guard()); - - py::class_(m, "UbufBootstrapCallbacks") - .def(py::init<>(), py::call_guard()) - .def(py::init(), - py::call_guard()); - - py::enum_(m, "UbufOverlapAlgo") - .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) - .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) - .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) - .value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P) - .value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P) - .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS) - .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P) - .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P); - - // Note: Can't release GIL in constructor since it may bootstrap - // communicator with Python functions (e.g. PyTorch distributed - // communication) - py::class_(m, "UbufCommOverlap") - .def(py::init(), - py::call_guard()) - .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap, - py::call_guard()) - .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs, - py::call_guard()) - .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv, - py::call_guard()) - .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs, - py::call_guard()) - .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf, - py::call_guard()) - .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output, - py::call_guard()) - .def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap, - py::call_guard()); - - // Note: Can't release GIL in constructor since it may bootstrap - // communicator with Python functions (e.g. 
PyTorch distributed - // communication) - py::class_(m, "UbufP2PCommOverlap") - .def(py::init(), - py::call_guard()) - .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag, - py::call_guard()) - .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs, - py::call_guard()) - .def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag, - py::call_guard()) - .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output, - py::call_guard()) - .def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf, - py::call_guard()) - .def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap, - py::call_guard()) - .def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv, - py::call_guard()); - - py::enum_(m, "DType", py::module_local()) - .value("kByte", transformer_engine::DType::kByte) - .value("kInt32", transformer_engine::DType::kInt32) - .value("kFloat32", transformer_engine::DType::kFloat32) - .value("kFloat16", transformer_engine::DType::kFloat16) - .value("kBFloat16", transformer_engine::DType::kBFloat16) - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); - py::enum_(m, "FP8FwdTensors") .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) @@ -329,41 +248,61 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + py::class_(m, "CommOverlapHelper") + .def(py::init<>(), py::call_guard()) + .def(py::init, + std::optional>(), + py::call_guard(), py::arg("world_group"), + py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - 
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); + py::class_(m, "CommOverlap") + .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, + int, int, bool, bool>(), + py::call_guard(), py::arg("buffer_shape"), + py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), + py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, + py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, + py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) + .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) + .def("split_overlap_rs", &CommOverlap::split_overlap_rs, + py::call_guard()) + .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs, + py::call_guard()) + .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf, + py::call_guard()) + .def("get_ubuf_output", &CommOverlap::get_ubuf_output, + py::call_guard()) + .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv, + py::call_guard()) + .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard()) + .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard()) + .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); - py::enum_(m, "NVTE_Fused_Attn_Backend") - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); + py::class_(m, "CommOverlapP2P") + .def(py::init &, at::ScalarType, CommOverlapHelper *, int, + transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(), + py::call_guard(), py::arg("buffer_shape"), + py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), + py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, + py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, + py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) + .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, + py::call_guard()) + .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs, + py::call_guard()) + .def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag, + py::call_guard()) + .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs, + py::call_guard()) + .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf, + py::call_guard()) + .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output, + py::call_guard()) + .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv, + py::call_guard()) + .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard()) + .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm, + py::call_guard()) + .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap, + py::call_guard()); } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 7bb81b8cd4..5ca34f7597 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -11,7 +11,7 @@ import fcntl import struct from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, List, 
Optional, Set, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager import torch @@ -88,9 +88,55 @@ def initialize_ub( ub_cfgs: Optional[dict] = None, bootstrap_backend: Union[str, torch.distributed.Backend] = None, ) -> None: - """Initialize communicators for TP comm overlap using userbuffers.""" + r""" + Initialize the Userbuffers communicator for overlapping tensor-parallel communications with + GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. + + Parameters + ---------- + shape : list + shape of the communication buffer, typically set to be the same as the global shape of + the input tensor to a te.TransformerLayer forward pass, with the sequence and batch + dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)` + tp_size : int + number of GPUs in the tensor-parallel process group + use_fp8 : bool = False + allocate the communication buffer for FP8 GEMM inputs/outputs + dtype : torch.dtype = torch.bfloat16 + non-FP8 data type of the communication buffer when `use_fp8 = False` + ub_cfgs: dict = None + Configuration dictionary with the structure + ``` + { + : { + "method": <"ring_exchange" or "pipeline">, + "is_reduce_scatter": bool, + "num_sm": int, + "cga_size": int, + "set_sm_margin": bool, + "num_splits": int, + "aggregate": bool, + "atomic_gemm": bool, + "use_ce": bool, + "fp8_buf": bool, + } + } + ``` + for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", + "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", + "fc2_fprop", "fc2_dgrad"]`. + bootstrap_backend : str = None + `torch.distributed` communication backend for the all-gather, broadcast and + barrier collectives during Userbuffers initialization. Not all backends are + valid for every cluster configuration and distributed launch method even if + they are available in PyTorch. When left unset, the initialization prefers + to use the MPI backend, falling back first on Gloo and then NCCL if MPI is + not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this + option and always initializes Userbuffers with direct MPI calls in C++, + which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. + """ if not tex.device_supports_multicast(): - assert bool(os.getenv("UB_SKIPMC", "0")), ( + assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." ) @@ -100,50 +146,52 @@ def initialize_ub( _ub_communicators = {} if tex.ubuf_built_with_mpi(): - # Userbuffers will ignore all these values when it is built with MPI, so these are just - # placeholders based on an assumption that tp_size covers all devices in a physical node. + # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force + # an MPI_Init() here by creating a new MPI process group... 
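Since the docstring above now spells out the full `initialize_ub()` signature, a brief usage sketch may help readers of this patch. The dimensions, TP size, and per-layer `ub_cfgs` overrides below are illustrative placeholders drawn from the parameter descriptions in the docstring, not recommended settings, and the sketch assumes a typical single-node tensor-parallel launch.

```python
# Illustrative sketch of the initialize_ub() call documented above (all values are placeholders).
import torch
import transformer_engine.pytorch as te

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

seq_length, batch_size, hidden_size, tp_size = 2048, 2, 4096, 8

te.initialize_ub(
    shape=[seq_length * batch_size, hidden_size],  # global input shape, seq and batch collapsed
    tp_size=tp_size,
    use_fp8=True,                                  # buffers usable for FP8 GEMM inputs/outputs
    dtype=torch.bfloat16,                          # non-FP8 buffer dtype
    ub_cfgs={
        # Overrides for a subset of the GEMM layer names listed in the docstring.
        "proj_fprop": {"method": "pipeline", "num_sm": 24, "num_splits": 4},
        "qkv_fprop": {"method": "ring_exchange", "aggregate": True},
    },
    bootstrap_backend="nccl",
)
```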
assert torch.distributed.is_mpi_available() - mpi_group = torch.distributed.new_group(backend="mpi") - world_rank = torch.distributed.get_rank(mpi_group) - world_size = torch.distributed.get_world_size(mpi_group) - local_rank = world_rank % tp_size - local_size = tp_size - self_node_idx = world_rank // tp_size - num_nodes = world_size // tp_size - ub_callbacks = tex.UbufBootstrapCallbacks() + _ = torch.distributed.new_group(backend="mpi") + helper = tex.CommOverlapHelper() else: + # Bootstrapping with torch.distributed API, so check backend and construct + # intra/inter-node process groups... assert ( torch.distributed.is_initialized() ), "torch.distributed must be initialized before Userbuffers" if bootstrap_backend is None: bootstrap_backend = "nccl" - if torch.distributed.is_gloo_available(): - bootstrap_backend = "gloo" - elif torch.distributed.is_mpi_available(): + if torch.distributed.is_mpi_available(): bootstrap_backend = "mpi" + elif torch.distributed.is_gloo_available(): + bootstrap_backend = "gloo" else: - assert bootstrap_backend in ["gloo", "mpi", "nccl"] + assert bootstrap_backend in [ + "gloo", + "mpi", + "nccl", + ], "Invalid torch.distributed backend for bootstrapping Userbuffers!" + assert torch.distributed.is_backend_available(bootstrap_backend), ( + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " + f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." + ) world_group = torch.distributed.new_group(backend=bootstrap_backend) world_rank = torch.distributed.get_rank(world_group) world_size = torch.distributed.get_world_size(world_group) - # Construct an intra-node communicator based on global ranks that share the same hostname - # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host - # address on that interface instead of the hostname. This can help avoid issues when - # different hosts have the same hostname on Kubernetes clusters. - hostname = socket.gethostname() + # We have single-node NVLink so we can color based on physical node hostnames. + # NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and + # otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on + # the chosen bootstrap backend. + mydomain = socket.gethostname() ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", - os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME") ) - if ifname is not None: # Make sure the ifname found in the environment is a valid network interface if ifname in [name for _, name in socket.if_nameindex()]: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: - hostname = socket.inet_ntoa( + mydomain = socket.inet_ntoa( fcntl.ioctl( s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) )[20:24] @@ -155,57 +203,64 @@ def initialize_ub( else: ifname_warning = ( f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" - " attempt to " - + "detect ranks on the same node by matching 'socket.gethostname()', which is " - + "known to fail on virtual clusters like Kubernetes. If Userbuffers " - + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in " - + "your environment to the correct network interface." + + " attempt to detect ranks on the same node by matching " + + "'socket.gethostname()', which is known to fail on virtual clusters like " + + "Kubernetes. 
If Userbuffers initialization fails, please set the " + + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network " + + "interface." ) warnings.warn(ifname_warning, UserWarning) - hostnames = [None for _ in range(world_size)] - torch.distributed.all_gather_object(hostnames, hostname, world_group) - unique_hosts = [] - for host in hostnames: - if host not in unique_hosts: - unique_hosts.append(host) - num_nodes = len(unique_hosts) - - if num_nodes > 1: - ranks_per_node_list = [[] for _ in range(num_nodes)] - self_node_idx = -1 - for i, host in enumerate(hostnames): - node_idx = unique_hosts.index(host) - ranks_per_node_list[node_idx].append(i) - if host == hostname: - self_node_idx = node_idx - assert self_node_idx >= 0, "Internal TE error!" - - intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration( - ranks_per_node_list, backend=bootstrap_backend + # Allgather the domain colors across ranks and reduce to a list of unique domains + domain_per_rank_list = [None for _ in range(world_size)] + torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group) + unique_domains = [] + for domain in domain_per_rank_list: + if domain not in unique_domains: + unique_domains.append(domain) + num_domains = len(unique_domains) + + if num_domains > 1: + # DP/TP model replicated on multiple NVLink domains + ranks_per_domain_list = [[] for _ in range(num_domains)] + mydomain_idx = -1 + for i, domain in enumerate(domain_per_rank_list): + domain_idx = unique_domains.index(domain) + ranks_per_domain_list[domain_idx].append(i) + if domain == mydomain: + mydomain_idx = domain_idx + assert mydomain_idx >= 0, "Internal TE error!" + + intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_domain_list, backend=bootstrap_backend + ) + local_rank = torch.distributed.get_rank(intra_domain_group) + intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group) + + inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + [list(ranks) for ranks in zip(*ranks_per_domain_list)], + backend=bootstrap_backend, ) - local_rank = torch.distributed.get_rank(intra_node_group) - local_size = torch.distributed.get_world_size(intra_node_group) - intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group) + + helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group) else: - self_node_idx = 0 - intra_node_group = world_group + # TP model on single NVLink domain, no replication, no data-parallelism + mydomain_idx = 0 local_rank = world_rank - local_size = world_size - intra_node_ranks = list(range(world_size)) + intra_domain_ranks = list(range(world_size)) + + helper = tex.CommOverlapHelper(world_group) if world_rank == 0: - print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True) + print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True) if local_rank == 0: print( - f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n", + f"!!! 
[UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n", end="", flush=True, ) - ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group) - # Increase the workspace by the number of maximum concurrent streams global _cublas_workspace _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) @@ -304,46 +359,34 @@ def add_ub( if atomic_gemm and method == "ring_exchange": assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message - sample_buffer = torch.empty( - shape, dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype, device="cuda" - ) + buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype if method == "ring_exchange": - ub_obj = tex.UbufP2PCommOverlap( - sample_buffer, # Sample userbuffer - world_rank, # World rank - world_size, # World size - local_rank, # Rank within the node - local_size, # Number of ranks/GPUs per node - self_node_idx, # Node ID - num_nodes, # Number of nodes + ub_obj = tex.CommOverlapP2P( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping tp_size, # Tensor-parallel group size (may be different than local_size) - num_sm, # Number of communication SMs - cga_size, # CGA cluster size - set_sm_margin, # Set SM margin - aggregate, # Aggregate 2X GEMM chunks - _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - is_reduce_scatter, # Overlap with reduce scatter - atomic_gemm, # Use a single GEMM with atomic-counters - use_ce, # Use copy engine for P2P communications - ub_callbacks, + tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + use_ce=use_ce, + aggregate=aggregate, ) else: - ub_obj = tex.UbufCommOverlap( - sample_buffer, # Sample userbuffer - world_rank, # World rank - world_size, # World size - local_rank, # Rank within the node - local_size, # Number of ranks/GPUs per node - self_node_idx, # Node ID - num_nodes, # Number of nodes + ub_obj = tex.CommOverlap( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping tp_size, # Tensor-parallel group size (may be different than local_size) - num_sm, # Number of communication SMs - cga_size, # CGA cluster size - num_splits, # Number of communication splits - set_sm_margin, # Set SM margin - _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - atomic_gemm, # Use a single GEMM with atomic-counters - ub_callbacks, + num_splits=num_splits, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, ) _ub_communicators[name] = ub_obj @@ -409,15 +452,6 @@ def __init__(self) -> None: self._fp8_workspaces: Dict[str, Float8Tensor] = {} self.activation_dtype: Optional[torch.dtype] = None - # Fast getter for parameters - # Note: torch.nn.Module does not store parameters like normal - # attrs, but rather in a dict. When attempting to access, the - # module will raise an AttributeError in __getattribute__ and - # call a custom __getattr__. This is unnecessary overhead if - # we know we are accessing a parameter. 
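The deleted comment above (and the `self._fast_get_param` lines removed just below) describe why the fast getter existed: `torch.nn.Module` keeps parameters in `self._parameters`, so ordinary attribute access falls through to the module's custom `__getattr__`. A small stand-alone sketch of the two access paths, using a stock `torch.nn.Linear` purely as a stand-in:

```python
# The two parameter-access paths referred to in the removed comment, shown on a stand-in module.
import torch

module = torch.nn.Linear(16, 16)

# Plain attribute access misses in the instance __dict__ and is resolved by
# nn.Module.__getattr__, which looks the name up in self._parameters.
w_attr = module.weight

# The deleted fast getter read the _parameters dict directly, skipping __getattr__.
w_fast = module.__dict__["_parameters"].get("weight")

# The patch replaces the fast getter with getattr()/plain attribute access.
w_getattr = getattr(module, "weight")

assert w_attr is w_fast is w_getattr  # all three return the same Parameter object
```

The patch standardizes on the `getattr()` path, as seen in the `grouped_linear.py`, `layernorm_linear.py`, `layernorm_mlp.py`, and `linear.py` hunks that follow.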
- self._fast_get_param: Callable[str, torch.nn.Parameter] - self._fast_get_param = self.__dict__["_parameters"].get - # Names of attributes that can be set quickly (see __setattr__ # method) _fast_setattr_names: Set[str] = { diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 866aef65a0..08c5addcfc 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -720,7 +720,7 @@ def forward( with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp: - weight_tensors = [self._fast_get_param(f"weight{i}") for i in range(self.num_gemms)] + weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fp8: weight_tensors = [ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6dea806993..fbf1b97704 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -161,9 +161,9 @@ def forward( if not return_layernorm_output: ln_out = torch.empty_like(ln_out) if ub_obj_lnout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif parallel_mode == "column" and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -293,7 +293,7 @@ def forward( get_workspace(), bias=bias, use_bias=use_bias, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) @@ -485,7 +485,7 @@ def backward( rs_out = None if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(grad_output.size()) @@ -496,14 +496,14 @@ def backward( ) if ub_obj_dgrad.is_p2p_overlap(): if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -616,7 +616,7 @@ def backward( out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -640,7 +640,7 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else 
None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -658,7 +658,7 @@ def backward( use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) clear_tensor_data(ln_out_total) @@ -1159,7 +1159,7 @@ def forward( with self.prepare_forward(inp, is_first_microbatch) as inp: # Get concatenated weight and bias tensors - unfused_weights = [self._fast_get_param(name) for name in self.weight_names] + unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: @@ -1170,9 +1170,9 @@ def forward( unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names]) + bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) else: - bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused + bias_tensor = getattr(self, self.bias_names[0]) # Unused # Initialize FP8 weights if needed weight_fp8 = None @@ -1206,8 +1206,8 @@ def forward( args = [None] args += ( inp, - self._fast_get_param("layer_norm_weight"), - self._fast_get_param("layer_norm_bias"), + self.layer_norm_weight, + self.layer_norm_bias, weight_tensor, weight_fp8, bias_tensor, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6c1633111d..64e8c9ce36 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -180,9 +180,9 @@ def forward( ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) if ub_obj_lnout.is_atomic_gemm(): - ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo_ag = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo_ag = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif set_parallel_mode and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -298,14 +298,14 @@ def forward( rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) if ub_obj_fc2out.is_p2p_overlap(): if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_fc2out.is_fp8_ubuf(): fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT @@ -369,7 +369,7 @@ def forward( bias=fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, gelu=not bias_gelu_nvfusion and (activation == "gelu"), - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out 
if ub_overlap_ag else None, ) @@ -410,9 +410,9 @@ def forward( dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) if ub_obj_fc2out.is_p2p_overlap(): - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(gelu_out.size()) dim_size[1] = fc2_weight.size(0) @@ -615,9 +615,9 @@ def backward( dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("fc2_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( @@ -788,7 +788,7 @@ def backward( # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap rs_out = None if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(dgelu.size()) @@ -797,14 +797,14 @@ def backward( rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) if ub_obj_dgrad.is_p2p_overlap(): if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -842,7 +842,7 @@ def backward( grad=True, gelu_input=fc1_out, ub_algo=( - tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None + tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None ), ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) @@ -892,7 +892,7 @@ def backward( # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(dgelu.size()) @@ -900,9 +900,9 @@ def backward( dim_size[1] = fc1_weight.size(1) rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) if ub_obj_dgrad.is_p2p_overlap(): - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -967,7 +967,7 @@ def backward( out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -991,7 +991,7 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if 
ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -1009,7 +1009,7 @@ def backward( use_bias=not ctx.bias_gelu_nvfusion, accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) clear_tensor_data(ln_out_total, dgelu) @@ -1491,10 +1491,10 @@ def forward( with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: # Get weight tensors - fc1_weight = self._fast_get_param("fc1_weight") - fc1_bias = self._fast_get_param("fc1_bias") - fc2_weight = self._fast_get_param("fc2_weight") - fc2_bias = self._fast_get_param("fc2_bias") + fc1_weight = self.fc1_weight + fc1_bias = self.fc1_bias + fc2_weight = self.fc2_weight + fc2_bias = self.fc2_bias if not self.fp8: if isinstance(fc1_weight, Float8Tensor): fc1_weight = fc1_weight.from_float8() @@ -1555,8 +1555,8 @@ def forward( args = [None] args += ( inp, - self._fast_get_param("layer_norm_weight"), - self._fast_get_param("layer_norm_bias"), + self.layer_norm_weight, + self.layer_norm_bias, fc1_weight, fc1_weight_fp8, fc1_bias, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f521cf4fb6..1fed467210 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -190,14 +190,14 @@ def forward( rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_projout.is_fp8_ubuf(): proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT meta_tensor = fp8_meta["scaling_fwd"] @@ -269,9 +269,9 @@ def forward( dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if ub_obj_projout.is_p2p_overlap(): - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features @@ -407,9 +407,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ( grad_output, @@ -496,7 +496,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], layout="NN", grad=True, ub_algo=( - tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + 
tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None ), @@ -950,7 +950,7 @@ def forward( ) as inp: # Get concatenated weight and bias tensors - unfused_weights = [self._fast_get_param(name) for name in self.weight_names] + unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: @@ -961,9 +961,9 @@ def forward( unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names]) + bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) else: - bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused + bias_tensor = getattr(self, self.bias_names[0]) # Unused # Initialize FP8 weights if needed weight_fp8 = None diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 191c98745d..93f6191dfe 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -3,11 +3,15 @@ # See LICENSE for license information. """Fused Adam optimizer.""" +from copy import deepcopy +from itertools import chain + import torch import transformer_engine_torch as tex from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from .multi_tensor_apply import multi_tensor_applier +from ..float8_tensor import Float8Tensor def get_fp8_meta(fp8_tensor): @@ -68,11 +72,28 @@ class FusedAdam(torch.optim.Optimizer): method is called. (default: True) capturable (bool, optional): whether to use the version of the optimizer that can be used with CUDA Graphs. (default: False) - master_weights (list of torch.Tensor, optional): master weights to use - for mixed precision training. If provided, the optimizer will update - the master weights and then cast the master weights to the model weights. - If not provided, the optimizer will update the model weights directly. - (default: None) + master_weights (bool, optional): whether to maintain FP32 master weights + in the optimizer with FP16/BF16 mixed precision training. + (default: False) + master_weight_dtype (torch.dtype, optional): The dtype of master weights. + If master_weights is False, this will be ignored. It can be one of + [torch.float32, torch.float16]. If it's not torch.float32, the optimizer + will create a FP32 scalar scaling factor to ensure precision. + (default: torch.float32) + exp_avg_dtype (torch.dtype, optional): The dtype of exp_avg. It can be + one of [torch.float32, torch.float16, torch.uint8], where torch.uint8 + represents FP8. If it's not torch.float32, the optimizer will create + a FP32 scalar scaling factor to ensure precision. + (default: torch.float32) + exp_avg_sq_dtype (torch.dtype, optional): The dtype of exp_avg_sq. It + can be one of [torch.float32, torch.float16, torch.uint8], where + torch.uint8 represents FP8. If it's not torch.float32, the optimizer + will create a FP32 scalar scaling factor to ensure precision. + (default: torch.float32) + use_decoupled_grad (bool, optional): Whether to use ".decoupled_grad" + instead of ".grad" for reading gradients. It's useful when the dtypes + of grad and param are different. + (default: False) .. 
_Adam - A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -92,12 +113,36 @@ def __init__( amsgrad=False, set_grad_none=True, capturable=False, - master_weights=None, + master_weights=False, + master_weight_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + use_decoupled_grad=False, ): if amsgrad: raise RuntimeError("FusedAdam does not support the AMSGrad variant.") + # Add constraints to dtypes of states. + if master_weights and master_weight_dtype not in [torch.float32, torch.float16]: + raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.") + if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]: + raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.") + if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]: + raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.") + + # Currently, capturable mode only supports fp32 master weights and optimizer states. + # The reason is, if the master weights or optimizer states are not in fp32 dtype, + # they will be copied to temporary fp32 buffers first. These fp32 buffers are then + # used as inputs for the kernel. Consequently, the pointer for earch `.step()` differs, + # making CUDA Graph inapplicable in this scenario. + if capturable and master_weights and master_weight_dtype != torch.float32: + raise RuntimeError("Capturable mode only supports fp32 master weights.") + if capturable and exp_avg_dtype != torch.float32: + raise RuntimeError("Capturable mode only supports fp32 exp_avg.") + if capturable and exp_avg_sq_dtype != torch.float32: + raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq") + # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr defaults = { @@ -112,9 +157,6 @@ def __init__( self.set_grad_none = set_grad_none self.capturable = capturable - - if master_weights is not None: - assert isinstance(master_weights, list), "master_weights must be a list if provided" self.master_weights = master_weights if capturable: @@ -134,14 +176,208 @@ def __init__( self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master + self.master_weight_dtype = master_weight_dtype + self.exp_avg_dtype = exp_avg_dtype + self.exp_avg_sq_dtype = exp_avg_sq_dtype + self.name_to_dtype_map = { + "exp_avg": self.exp_avg_dtype, + "exp_avg_sq": self.exp_avg_sq_dtype, + "master_param": self.master_weight_dtype, + } + self.dtype_to_range_map = { + torch.float16: torch.full( + [1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32 + ), + torch.uint8: torch.full([1], 448.0, dtype=torch.float32), + } + self._scales = {} + self.use_decoupled_grad = use_decoupled_grad + def zero_grad(self): # pylint: disable=missing-function-docstring - if self.set_grad_none: - for group in self.param_groups: - for p in group["params"]: + if not self.use_decoupled_grad and not self.set_grad_none: + super().zero_grad() + return + + for group in self.param_groups: + for p in group["params"]: + if self.use_decoupled_grad and self.set_grad_none: + p.decoupled_grad = None + elif self.use_decoupled_grad and not self.set_grad_none: + p.decoupled_grad.zero_() + elif not self.use_decoupled_grad and self.set_grad_none: p.grad = None + + def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): + """Apply scaling on `unscaled_state`. 
`scaled_state` and `scale` will be written inplace. + + Arguments: + state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', + and 'master_param`. + unscaled_state (torch.Tensor): An unscaled high-precision tensor. + scaled_state (torch.Tensor): An scaled low-precision tensor. + scale (torch.Tensor): A FP32 tensor representing the scaling factor. + """ + assert unscaled_state.dtype == torch.float32 + dtype = self.name_to_dtype_map[state_name] + if dtype == torch.uint8: + assert isinstance(scaled_state, Float8Tensor) else: - super().zero_grad() + assert scaled_state.dtype == dtype + + max_range = self.dtype_to_range_map[dtype] + if max_range.device != scaled_state.device: + max_range = max_range.to(scaled_state.device) + self.dtype_to_range_map[scaled_state.dtype] = max_range + if unscaled_state.device != scaled_state.device: + unscaled_state = unscaled_state.to(scaled_state.device) + min_val, max_val = torch.aminmax(unscaled_state) + absmax = torch.maximum(-min_val, max_val) + absmax = absmax.to(dtype=torch.float32, device=unscaled_state.device) + torch.div(absmax, max_range, out=scale) + if isinstance(scaled_state, Float8Tensor): + scaled_state._scale_inv.copy_(scale) + scaled_state.copy_(unscaled_state) + else: + rscale = torch.where(scale > 0, scale.reciprocal(), 0.0) + unscaled_state.mul_(rscale) + scaled_state.copy_(unscaled_state) + + def get_unscaled_state(self, param, state_name): + """Return the unscaled state corresponding to the input `param` and `state_name`. + + Arguments: + param (torch.nn.Parameter): One of parameters in this optimizer. + state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', + and 'master_param`. + """ + state = self.state[param] + dtype = self.name_to_dtype_map[state_name] + if dtype == torch.uint8: + assert isinstance(state[state_name], Float8Tensor) + unscaled = state[state_name].float() + elif dtype == torch.float16: + assert state[state_name].dtype == torch.float16 + unscaled = state[state_name].float() + unscaled.mul_(self._scales[param][state_name]) + elif dtype == torch.float32: + assert state[state_name].dtype == torch.float32 + unscaled = state[state_name] + else: + raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.") + return unscaled + + def set_scaled_state(self, param, state_name, unscaled_state): + """Set the optimizer state. + + If the dtype of the corresponding optimizer state is not FP32, + it will do scaling automatically. + + Arguments: + param (torch.nn.Parameter): One of parameters in this optimizer. + state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', + and 'master_param`. + unscaled_state (torch.Tensor): The original high-precision(FP32) state. + """ + assert unscaled_state.dtype == torch.float32 + state = self.state[param] + if state_name not in state: + self._initialize_state(param, state_name, False) + + dtype = self.name_to_dtype_map[state_name] + if dtype != torch.float32: + scale = self._scales[param] + self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name]) + else: + state[state_name].copy_(unscaled_state) + + def _initialize_state(self, param, state_name, zero_buffer: bool): + """Initialize one of the optimizer states according to `state_name`. + + Arguments: + param (torch.nn.Parameter): One of parameters in this optimizer. + state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', + and 'master_param`. 
+ zero_buffer (bool): Whether to initialize the optimizer state with zeros. + """ + dtype = self.name_to_dtype_map[state_name] + data = torch.empty_like(param, dtype=dtype) + if zero_buffer: + data.zero_() + + if dtype == torch.uint8: + self.state[param][state_name] = Float8Tensor( + data=data, + dtype=torch.float32, + fp8_scale_inv=torch.ones([1], dtype=torch.float32, device=param.device), + ) + else: + self.state[param][state_name] = data + + # Create scale if necessary. + if dtype != torch.float32: + if param not in self._scales: + self._scales[param] = {} + self._scales[param][state_name] = torch.ones( + [1], dtype=torch.float32, device=param.device + ) + + def initialize_state(self, param): + """Initialize optimizer states. + + Arguments: + param (torch.nn.Parameter): One of parameters in this optimizer. + """ + self._initialize_state(param, "exp_avg", zero_buffer=True) + self._initialize_state(param, "exp_avg_sq", zero_buffer=True) + if self.master_weights: + self._initialize_state(param, "master_param", zero_buffer=False) + self.set_scaled_state(param, "master_param", param.clone().detach().float()) + + def state_dict(self): + """Override the state_dict() of pytorch. Before returning the state_dict, cast all + non-fp32 states to fp32. + """ + state_dict = super().state_dict() + + groups = self.param_groups + saved_groups = deepcopy(state_dict["param_groups"]) + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + ) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + new_v = {} + for name in v: + new_v[name] = self.get_unscaled_state(param, name) + state_dict["state"][k] = new_v + + return state_dict + + def load_state_dict(self, state_dict): + """Override the load_state_dict() of pytorch. Since pytorch's load_state_dict forces the + state to be the same dtype as param, We need to manully set the state again. + """ + super().load_state_dict(state_dict) + + groups = self.param_groups + saved_groups = deepcopy(state_dict["param_groups"]) + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + ) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + self.state[param] = {} + for name in v: + self.set_scaled_state(param, name, v[name].float()) def step(self, closure=None, grad_scaler=None): """Performs a single optimization step. @@ -156,8 +392,6 @@ def step(self, closure=None, grad_scaler=None): if closure is not None: loss = closure() - master_param_idx = 0 - for group in self.param_groups: if len(group["params"]) == 0: continue @@ -196,6 +430,11 @@ def step(self, closure=None, grad_scaler=None): amaxes = [] scale_invs = [] + # Lists for scaling + unscaled_lists = {"exp_avg": [], "exp_avg_sq": [], "master_param": []} + scaled_lists = {"exp_avg": [], "exp_avg_sq": [], "master_param": []} + state_scales = {"exp_avg": [], "exp_avg_sq": [], "master_param": []} + # Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is. 
out_dtype = tex.DType.kFloat32 @@ -207,31 +446,29 @@ def step(self, closure=None, grad_scaler=None): # State initialization if len(state) == 0: - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(p.data).float() - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p.data).float() - # Master weights - if self.master_weights and p.dtype != torch.float32: - # model weights can be fp32/bf16/fp16/fp8 - # If it's fp32, it has no corresponding master weights - state["master_param"] = self.master_weights[master_param_idx] - master_param_idx += 1 - assert ( - state["master_param"].shape == p.shape - ), "Master weights shape must match model weights shape" - - p_master = state.get("master_param", None) - p_grad = p.grad - - if self.master_weights and p_master is not None and p_master.grad is not None: - p_grad = p_master.grad + self.initialize_state(p) + + if self.use_decoupled_grad: + p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None + else: + p_grad = p.grad if p_grad is None: continue if p_grad.data.is_sparse: raise RuntimeError("FusedAdam does not support sparse gradients.") + # Unscaling + unscaled_state = {} + for name in ["exp_avg", "exp_avg_sq", "master_param"]: + if name in state: + unscaled = self.get_unscaled_state(p, name) + unscaled_state[name] = unscaled + if self.name_to_dtype_map[name] != torch.float32: + unscaled_lists[name].append(unscaled) + scaled_lists[name].append(state[name]) + state_scales[name].append(self._scales[p][name]) + if isinstance(p, Float8Tensor): out_dtype = p._fp8_dtype p_fp8_model.append(p._data.data) @@ -240,26 +477,28 @@ def step(self, closure=None, grad_scaler=None): amaxes.append(amax) scale_invs.append(scale_inv) if self.master_weights: - p_main_of_fp8_model.append(p_master.data) + p_main_of_fp8_model.append(unscaled_state["master_param"].data) g_of_fp8_model.append(p_grad.data) - m_of_fp8_model.append(state["exp_avg"]) - v_of_fp8_model.append(state["exp_avg_sq"]) + m_of_fp8_model.append(unscaled_state["exp_avg"]) + v_of_fp8_model.append(unscaled_state["exp_avg_sq"]) elif p.dtype in [torch.float16, torch.bfloat16]: has_fp16 = has_fp16 or p.dtype == torch.float16 has_bf16 = has_bf16 or p.dtype == torch.bfloat16 p_f16_model.append(p.data) if self.master_weights: - p_main_of_f16_model.append(p_master.data) + p_main_of_f16_model.append(unscaled_state["master_param"].data) g_of_f16_model.append(p_grad.data) - m_of_f16_model.append(state["exp_avg"]) - v_of_f16_model.append(state["exp_avg_sq"]) + m_of_f16_model.append(unscaled_state["exp_avg"]) + v_of_f16_model.append(unscaled_state["exp_avg_sq"]) elif p.dtype == torch.float32: p_f32_model.append(p.data) g_of_f32_model.append(p_grad.data) - m_of_f32_model.append(state["exp_avg"]) - v_of_f32_model.append(state["exp_avg_sq"]) + m_of_f32_model.append(unscaled_state["exp_avg"]) + v_of_f32_model.append(unscaled_state["exp_avg_sq"]) else: - raise RuntimeError("FusedAdam only support model weights in fp16/bf16 and fp8") + raise RuntimeError( + "FusedAdam only support model weights in fp32, fp16, bf16 and fp8" + ) if self.capturable and len(p_fp8_model) > 0: raise RuntimeError( @@ -389,4 +628,15 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model] apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + # Scaling + for name in ["exp_avg", "exp_avg_sq", "master_param"]: + if len(unscaled_lists[name]) > 0: + for 
unscaled, scaled, scale in zip( + unscaled_lists[name], scaled_lists[name], state_scales[name] + ): + self._apply_scale(name, unscaled, scaled, scale) + + # Try to reclaim the temporary fp32 buffers. + del unscaled_lists + return loss
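To close out the `FusedAdam` changes, a minimal usage sketch of the new keyword arguments. The import path is assumed from this patch's file layout, and the toy model, chosen state dtypes, and training step are illustrative only; argument semantics follow the docstring added above.

```python
# Illustrative use of the new FusedAdam options (import path assumed from this patch's layout).
import torch
from transformer_engine.pytorch.optimizers import FusedAdam

model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")

optimizer = FusedAdam(
    model.parameters(),
    lr=1e-4,
    master_weights=True,              # keep high-precision master copies of the BF16 params
    master_weight_dtype=torch.float32,
    exp_avg_dtype=torch.uint8,        # FP8 first moment, stored as a scaled Float8Tensor
    exp_avg_sq_dtype=torch.float16,   # FP16 second moment with an FP32 scaling factor
    use_decoupled_grad=True,          # read gradients from param.decoupled_grad
)

loss = model(torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")).float().sum()
loss.backward()

# With use_decoupled_grad=True the optimizer reads `.decoupled_grad` instead of `.grad`,
# which allows the gradient dtype to differ from the parameter dtype.
for p in model.parameters():
    p.decoupled_grad = p.grad.float()
    p.grad = None

optimizer.step()
optimizer.zero_grad()

# Non-FP32 states are unscaled to FP32 on the way out and rescaled on the way back in.
checkpoint = optimizer.state_dict()
optimizer.load_state_dict(checkpoint)
```

The `get_unscaled_state()` / `set_scaled_state()` helpers added above expose the same FP32 round trip for a single state tensor; they are what `state_dict()` and `load_state_dict()` use internally.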