Commit 35c9836

Merge branch 'NVIDIA:main' into fused_out_correction

xiaoyao0115 authored Nov 19, 2024
2 parents e33b2a2 + 994f19d commit 35c9836
Showing 89 changed files with 7,766 additions and 1,936 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -76,7 +76,7 @@ jobs:
name: 'PaddlePaddle'
runs-on: ubuntu-latest
container:
image: nvcr.io/nvidia/paddlepaddle:24.07-py3
image: nvcr.io/nvidia/paddlepaddle:24.10-py3
options: --user root
steps:
- name: 'Checkout'
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
@@ -39,6 +39,7 @@ jobs:
|| github.actor == 'pggPL'
|| github.actor == 'vasunvidia'
|| github.actor == 'erhoo82'
|| github.actor == 'kocchop'
)
steps:
- name: Check if comment is issued by authorized person
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
1.13.0.dev0
1.14.0.dev0
2 changes: 1 addition & 1 deletion build_tools/paddle.py
@@ -25,7 +25,7 @@ def setup_paddle_extension(
# Source files
csrc_source_files = Path(csrc_source_files)
sources = [
csrc_source_files / "extensions.cu",
csrc_source_files / "extensions.cpp",
csrc_source_files / "common.cpp",
csrc_source_files / "custom_ops.cu",
]
2 changes: 1 addition & 1 deletion build_tools/pytorch.py
@@ -26,7 +26,7 @@ def setup_pytorch_extension(
csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "common.cu",
csrc_source_files / "common.cpp",
csrc_source_files / "ts_fp8_op.cpp",
] + all_files_in_dir(extensions_dir)

2 changes: 2 additions & 0 deletions qa/L0_jax_unittest/test.sh
@@ -18,5 +18,7 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
9 changes: 8 additions & 1 deletion qa/L1_jax_distributed_unittest/test.sh
@@ -5,4 +5,11 @@
set -xe

: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*

# Skip ring attention tests since they need fixed environment vars
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'

# Test ring attention with and without scan loop
NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
@@ -10,4 +10,5 @@ pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
30 changes: 29 additions & 1 deletion tests/jax/test_custom_call_compute.py
@@ -22,6 +22,7 @@
from transformer_engine.jax.cpp_extensions.transpose import (
_jax_transpose,
_jax_cast_transpose,
_jax_dbias_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex
@@ -504,7 +505,6 @@ def _prim_func_bwd(ctx, g):
scale_inv,
FP8Helper.BWD_DTYPE,
-1,
-2,
self.activation_type,
)
)
@@ -812,6 +812,34 @@ def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)

@pytest.mark.parametrize(
"out_dtype",
[
pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
],
)
def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
amax = jnp.zeros(1, jnp.float32)
scale = jnp.ones(1, jnp.float32)
scale_inv = jnp.ones(1, jnp.float32)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
static_axis_boundary = -1
jax_output = _jax_dbias_cast_transpose(
input, amax, scale, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.dbias_cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.dbias_cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
164 changes: 115 additions & 49 deletions tests/jax/test_distributed_fused_attn.py
@@ -35,7 +35,9 @@
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource

# We will use the golden reference model from our non-distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask
@@ -133,7 +135,6 @@ def test_self_attn(
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backend found")

@@ -268,7 +269,6 @@ def test_cross_attn(
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backend found")

@@ -335,6 +335,36 @@ def ref_func(query, kv, mask):
)


@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
class TestDistributedContextParallelSelfAttn:

def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
@@ -372,37 +402,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args

@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
def test_contex_parallel_self_attn(
def impl_test_context_parallel_attn(
self,
device_count,
mesh_shape,
@@ -414,6 +414,7 @@ def test_contex_parallel_self_attn(
dtype,
qkv_layout,
load_balanced,
cp_strategy,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
Expand All @@ -425,22 +426,32 @@ def test_contex_parallel_self_attn(
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

if not is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None, # no window
cp_size > 1,
):
pytest.skip(f"No FusedAttn backend found")
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None,  # no SWA for CP
)

# For causal masking we depend on having bottom-right support as well.
# The API does not check this; instead we rely on lower-level checks to raise
# an exception if the selected backend is not supported. This was a deliberate API
# decision to keep the CP size or flag out of the function.
has_backend = check_has_backend_for_mask(attn_mask_type)
if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

if not has_backend:
pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

if dp_size > 1 and batch % dp_size != 0:
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
Expand All @@ -461,6 +472,7 @@ def target_func(q, k, v, mask):
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_strategy=cp_strategy,
context_parallel_causal_load_balanced=load_balanced,
context_parallel_axis="cp",
).astype(dtype)
@@ -566,6 +578,60 @@ def grad_func(func, *args, **kwargs):

assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)

def test_context_parallel_allgather_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
)

def test_context_parallel_ring_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.RING,
)


class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
1 change: 1 addition & 0 deletions tests/jax/test_fused_attn.py
@@ -7,6 +7,7 @@
from functools import partial
from math import sqrt
from typing import Tuple, Optional
import random

import jax
import jax.numpy as jnp
40 changes: 40 additions & 0 deletions tests/jax/test_misc.py
@@ -0,0 +1,40 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
from functools import partial
import os

from transformer_engine.jax.cpp_extensions.misc import get_xla_flag


@pytest.fixture(autouse=True, scope="function")
def preserve_xla_flags():
"""Ensures the XLA flags environment variable is restored after any tests in this file run."""
old_flags = os.getenv("XLA_FLAGS")
yield
if old_flags is not None:
    os.environ["XLA_FLAGS"] = old_flags
else:
    # If XLA_FLAGS was unset before the test, unset it again so no state leaks.
    os.environ.pop("XLA_FLAGS", None)


def test_get_xla_flag(request):
os.environ["XLA_FLAGS"] = ""
assert get_xla_flag("") is None
assert get_xla_flag("--foo") is None
assert get_xla_flag("--bar=1") is None

os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz"
assert get_xla_flag("--foo") == True
assert get_xla_flag("--bar") == "1"
assert get_xla_flag("--bar", cast=int) == 1
assert get_xla_flag("--bar", cast=bool) == True
assert get_xla_flag("--baz") == "biz"
with pytest.raises(ValueError):
# cast will fail
assert get_xla_flag("--baz", cast=int)
assert get_xla_flag("--xla") is None

os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb"
assert get_xla_flag("--xla_abc") == True
assert get_xla_flag("--xla_abb") == True