92 changes: 92 additions & 0 deletions gmm2.py
@@ -0,0 +1,92 @@
import os
import time
import torch
import transformer_engine.pytorch as te

torch.manual_seed(0)

os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1"
os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1"

device = "cuda"
dtype = torch.bfloat16

E = 4
K = 1024
N = 2048
m_splits = [128, 64, 0, 256]  # rows per group (note the empty group)
M_total = sum(m_splits)

x = torch.randn(M_total, K, device=device, dtype=dtype)

# Timing helper
def bench_cuda(fn, warmup=20, iters=100, name=""):
# Warmup
for _ in range(warmup):
fn()
torch.cuda.synchronize()

# Timed
start = time.time()
for _ in range(iters):
fn()
torch.cuda.synchronize()
end = time.time()

avg_ms = (end - start) * 1000.0 / iters
if name:
print(f"{name}: {avg_ms:.3f} ms (avg over {iters} runs, {warmup} warmup)")
return avg_ms

# TE GroupedLinear
glinear = te.GroupedLinear(E, K, N, bias=False).to(device=device, dtype=dtype)

def te_run():
return glinear(x, m_splits=m_splits)

te_ms = bench_cuda(te_run, warmup=20, iters=100, name="TE GroupedLinear")

# Grab weights for reference path
Ws = [getattr(glinear, f"weight{e}") for e in range(E)] # each [N, K]
W = torch.stack(Ws, dim=0) # [E, N, K]
assert W.shape == (E, N, K), f"Unexpected weight shape: {W.shape}"

# Torch reference (group loop)
offsets = []
off = 0
for m in m_splits:
offsets.append(off)
off += m

y_ref_buf = torch.empty((M_total, N), device=device, dtype=dtype)

def torch_run():
# Fill the preallocated buffer
for e, m in enumerate(m_splits):
if m == 0:
continue
o = offsets[e]
y_ref_buf[o:o+m].copy_(x[o:o+m] @ W[e].transpose(0, 1))
return y_ref_buf

torch_ms = bench_cuda(torch_run, warmup=20, iters=100, name="Torch loop (prealloc out)")

# Compare outputs
y_te = te_run()
y_ref = torch_run().clone()

diff = (y_te.float() - y_ref.float())
max_abs = diff.abs().max().item()
rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item()

print(f"\nErrors:")
print(f" {y_te.shape=}, {y_ref.shape=}")
print(" max_abs_err:", max_abs)
print(" max_rel_err:", rel)

torch.testing.assert_close(y_te.float(), y_ref.float(), rtol=3e-2, atol=3e-2)

print(f"\nTiming:")
print(f" TE avg: {te_ms:.3f} ms")
print(f" Torch avg: {torch_ms:.3f} ms")
print(f" Speedup: {torch_ms/te_ms:.2f}x (Torch / TE)")
16 changes: 11 additions & 5 deletions tests/pytorch/test_numerics.py
@@ -1385,7 +1385,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_

if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")

te_linear_ref = Linear(
config.hidden_size,
@@ -1677,7 +1677,7 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute(
):
if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
config = model_configs[model]

ln_linear_ref = LayerNormLinear(
@@ -1891,7 +1891,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(

if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")

ln_mlp = LayerNormMLP(
hidden_size=config.hidden_size,
@@ -2036,7 +2036,7 @@ def test_grouped_linear_accuracy(

if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8:
pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")

@@ -2115,7 +2115,7 @@ def test_grouped_linear_accuracy(


@pytest.mark.skipif(
torch.cuda.get_device_capability() != (9, 0),
torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
reason="Only enable CUTLASS grouped gemm on Hopper",
)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@@ -2133,6 +2133,9 @@ def test_grouped_linear_accuracy_cutlass(
delay_wgrad_compute,
):
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
if IS_HIP_EXTENSION:
os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1"
os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1"
test_grouped_linear_accuracy(
dtype,
num_gemms,
@@ -2147,6 +2150,9 @@ def test_grouped_linear_accuracy_cutlass(
use_cutlass=True,
)
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
if IS_HIP_EXTENSION:
os.environ.pop("NVTE_USE_CK_GROUPED_GEMM", None)
os.environ.pop("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", None)


@pytest.mark.parametrize("dtype", param_types, ids=str)
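The ROCm branch added to test_grouped_linear_accuracy_cutlass above sets NVTE_USE_CK_GROUPED_GEMM and NVTE_CK_GROUPED_GEMM_WARN_FALLBACK before running the shared accuracy test and pops them afterwards. A minimal sketch of the same toggle, written as a helper so the flags are restored even if the wrapped call raises (illustrative only, not part of the patch; the helper name is hypothetical):

import os

def run_with_ck_grouped_gemm(fn, *args, **kwargs):
    # Hypothetical helper: enable the CK grouped-GEMM flags, run fn, then restore the environment.
    os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1"
    os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1"
    try:
        return fn(*args, **kwargs)
    finally:
        os.environ.pop("NVTE_USE_CK_GROUPED_GEMM", None)
        os.environ.pop("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", None)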
8 changes: 8 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -241,6 +241,14 @@ endif()
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")

set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)

target_include_directories(transformer_engine
BEFORE PRIVATE
${CK_ROOT}/include
)


if (USE_CUDA)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(