[release/2.7][ROCm][TunableOp] TunableOp TF32 support #2049

Closed
wants to merge 4 commits into from
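
In short: when torch.backends.cuda.matmul.allow_tf32 is set, TunableOp now tunes float32 GEMMs with the hipBLASLt HIPBLAS_COMPUTE_32F_FAST_TF32 compute type, records them under a "tf32" op signature instead of "float", and fails over away from rocBLAS candidates (rocBLAS has no TF32 support). A minimal sketch of how to exercise this, based on the test_tf32_tunableop test added below (ROCm MI300+ only; the HIPBLASLT_ALLOW_TF32 gate is temporary and the matrix size is illustrative):

import os
import torch

# On HIP, TF32 is additionally gated behind an environment variable (MI300+ only).
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
torch.backends.cuda.matmul.allow_tf32 = True

torch.cuda.tunable.enable()                      # turn on TunableOp
torch.cuda.tunable.set_max_tuning_iterations(1)  # keep the tuning pass short

A = torch.randn(37, 37, device="cuda")
B = torch.randn(37, 37, device="cuda")
C = torch.matmul(A, B)                           # tuned as a TF32 GEMM

# The new entry is recorded under a tf32 op signature rather than float.
print([r for r in torch.cuda.tunable.get_results() if "GemmTunableOp_tf32" in str(r)])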
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-96316ce50fade7e209553aba4898cd9b82aab83b
+3dbfe504e9cd7ba3dec5f63938ed2715c2ba2555
1 change: 1 addition & 0 deletions .github/scripts/build_triton_wheel.py
@@ -68,6 +68,7 @@ def build_triton(
triton_repo = "https://github.com/openai/triton"
if device == "rocm":
triton_pkg_name = "pytorch-triton-rocm"
triton_repo = "https://github.com/ROCm/triton"
elif device == "xpu":
triton_pkg_name = "pytorch-triton-xpu"
triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"
13 changes: 11 additions & 2 deletions aten/src/ATen/cuda/tunable/GemmHipblaslt.h
@@ -498,7 +498,11 @@ class HipblasltGemmOp : public Callable<ParamsT> {
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
}

-HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
+hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
+if (at::globalContext().allowTF32CuBLAS()) {
+  computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
+}
+HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);

@@ -611,6 +615,11 @@ auto GetHipBlasLtTypeStringAndOps() {
auto in_out_datatype = HipDataTypeFor<CT>();
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;

hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
if (at::globalContext().allowTF32CuBLAS()) {
computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
}

hipblasLtHandle_t handle;
TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
@@ -621,7 +630,7 @@ auto GetHipBlasLtTypeStringAndOps() {
b_datatype,
in_out_datatype,
in_out_datatype,
-HIPBLAS_COMPUTE_32F,
+computeType,
heuristic_result));
TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));

4 changes: 4 additions & 0 deletions aten/src/ATen/cuda/tunable/GemmRocblas.h
@@ -141,6 +141,8 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {

TuningStatus Call(const GemmParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
return FAIL; // no support for TF32 in rocBLAS
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
@@ -207,6 +209,8 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>

TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
return FAIL; // no support for TF32 in rocBLAS
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
6 changes: 5 additions & 1 deletion aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -145,7 +145,11 @@ inline const char* TypeName(T v) {

template <>
inline const char* TypeName(float v) {
return "float";
if (at::globalContext().allowTF32CuBLAS()) {
return "tf32";
} else {
return "float";
}
}

template <>
198 changes: 198 additions & 0 deletions test/test_linalg.py
@@ -41,6 +41,7 @@
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
import operator
import contextlib

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
@@ -96,7 +97,30 @@ def get_tunableop_validators():
validators[key] = value
return validators

def find_tunableop_result(results, OpSig, ParamSig):
assert isinstance(results, tuple)
for inner_tuple in results:
if OpSig in inner_tuple and ParamSig in inner_tuple:
return inner_tuple
return None

class TestLinalg(TestCase):
@contextlib.contextmanager
def _hip_allow_tf32(self):
# for HIP/AMDGPU, TF32 is behind a flag because TF32 support is new
# and only for MI300+. The environment variable will be removed in the future.
import os
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

try:
yield
finally:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]

def setUp(self):
super(self.__class__, self).setUp()
torch.backends.cuda.matmul.allow_tf32 = False
@@ -5465,6 +5489,180 @@ def test_scaled_gemm_tunableop(self, device, dtype):
except FileNotFoundError:
pass

@onlyCUDA
@skipCUDAIfNotRocm
@runOnRocmArch(MI300_ARCH)
@dtypes(torch.float)
def test_tf32_tunableop(self, device, dtype):
# Test TunableOp with TF32. Supported by hipBLASLt on MI300+.
# For HIP/AMDGPU, TF32 is behind a flag because TF32 support is new
# and only for MI300+. Eventually this flag will go away.
import os

tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext

try:
with tf32_ctx():
torch.backends.cuda.matmul.allow_tf32 = True
set_tunableop_defaults()
torch.cuda.tunable.set_rotating_buffer_size(0)
torch.cuda.tunable.enable()

# Reference number of results
ref_num_results = len(torch.cuda.tunable.get_results())

N = M = K = 37
A = torch.randn(N, K, device=device, dtype=dtype)
B = torch.randn(K, M, device=device, dtype=dtype)
C = torch.matmul(A, B)

# This stores total number of cumulative results
total_num_results = len(torch.cuda.tunable.get_results())

# There must be a new tuning result
self.assertEqual((total_num_results - ref_num_results), 1)

# The results must NOT be from rocBLAS
# result can be either Default or Hipblaslt
# Additionally, the Op Signature must be tf32
last_result = torch.cuda.tunable.get_results()
found_result = find_tunableop_result(last_result,
'GemmTunableOp_tf32_NN',
'nn_37_37_37_ld_37_37_37')
self.assertTrue(found_result is not None)
self.assertTrue('Rocblas' not in found_result)


# Now disable TF32
torch.backends.cuda.matmul.allow_tf32 = False

# Update the number of reference results
ref_num_results = total_num_results

# Tune the same GEMM again
C = torch.matmul(A, B)

# This stores total number of cumulative results
total_num_results = len(torch.cuda.tunable.get_results())

# There must be a new tuning result
self.assertEqual((total_num_results - ref_num_results), 1)

# The new tuning result must be of type float
last_result = torch.cuda.tunable.get_results()
found_result = find_tunableop_result(last_result,
'GemmTunableOp_float_NN',
'nn_37_37_37_ld_37_37_37')
self.assertTrue(found_result is not None)

finally:
# Disable TF32
torch.backends.cuda.matmul.allow_tf32 = False

# disable TunableOp
torch.cuda.tunable.enable(False)

try:
filename = torch.cuda.tunable.get_filename()
os.remove(filename)
except FileNotFoundError:
pass

@onlyCUDA
@skipCUDAIfNotRocm
@runOnRocmArch(MI300_ARCH)
@dtypes(torch.float)
def test_tf32_offline_tunableop(self, device, dtype):
# This test is the offline version of test_tf32_tunableop
import os

ordinal = torch.cuda.current_device()

# Test TunableOp with TF32. Supported by hipBLASLt on MI300+.
# For HIP/AMDGPU, TF32 is behind a flag because TF32 support is new
# and only for MI300+. Eventually this flag will go away.
tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext

# Test in try-finally block to avoid leaking state
# if test is interrupted.
try:
with tf32_ctx():
torch.backends.cuda.matmul.allow_tf32 = True
set_tunableop_defaults()
torch.cuda.tunable.set_rotating_buffer_size(0)

result_filename = f"tunableop_results{ordinal}.csv"
os.putenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME", "tunableop_untuned.csv")
torch.cuda.tunable.set_filename(result_filename)

torch.cuda.tunable.enable()
# record GEMM
torch.cuda.tunable.tuning_enable(False)
torch.cuda.tunable.record_untuned_enable(True)
self.assertTrue(torch.cuda.tunable.record_untuned_is_enabled())

N = M = K = 41
A = torch.randn(N, K, device=device, dtype=dtype)
B = torch.randn(K, M, device=device, dtype=dtype)
C = torch.matmul(A, B)

# Now disable TF32
torch.backends.cuda.matmul.allow_tf32 = False
C = torch.matmul(A, B)

untuned_filename = f"tunableop_untuned{ordinal}.csv"
self.assertTrue(os.path.exists(untuned_filename))

# tuning the untuned GEMMs in file
torch.cuda.tunable.tuning_enable(True)
torch.cuda.tunable.record_untuned_enable(False)

# set these to single iterations to keep it short but still exercise the code
torch.cuda.tunable.set_max_tuning_duration(1)
torch.cuda.tunable.set_max_tuning_iterations(1)

ref_results = len(torch.cuda.tunable.get_results())
torch.cuda.tunable.tune_gemm_in_file(untuned_filename)
new_results = len(torch.cuda.tunable.get_results())

# This stores the number of new tuning results
total_num_results = new_results - ref_results

# There must be two new tuning results (one tf32, one float)
self.assertEqual(total_num_results, 2)

last_result = torch.cuda.tunable.get_results()
found_result = find_tunableop_result(last_result,
'GemmTunableOp_tf32_NN',
'nn_41_41_41_ld_41_41_41')
self.assertTrue(found_result is not None)

found_result = find_tunableop_result(last_result,
'GemmTunableOp_float_NN',
'nn_41_41_41_ld_41_41_41')
self.assertTrue(found_result is not None)

finally:
# Disable TF32
torch.backends.cuda.matmul.allow_tf32 = False

# disable TunableOp
torch.cuda.tunable.enable(False)

# undo all the environment variables set
try:
del os.environ["PYTORCH_TUNABLEOP_UNTUNED_FILENAME"]
except KeyError:
pass

# clean up, remove any files that were generated
for filename in [untuned_filename, result_filename]:
try:
os.remove(filename)
# NB: The file is locked on Windows
except (FileNotFoundError, PermissionError):
pass

@dtypes(torch.float, torch.complex64)
def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
7 changes: 7 additions & 0 deletions torch/cuda/tunable.py
@@ -444,6 +444,7 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:

dtype_dict = {
"float": torch.float32,
"tf32": torch.float32,
"double": torch.float64,
"BFloat16": torch.bfloat16,
"Half": torch.half,
@@ -470,6 +471,12 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
transA = layout[0] == "T"
transB = layout[1] == "T"
dtype = dtype_dict.get(data_type)
if data_type == "tf32":
# User must still set HIPBLASLT_ALLOW_TF32=1
torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.backends.cuda.matmul.allow_tf32 = False

else: # ScaledGEMM
untuned_gemm_temp = untuned_gemm[0].split("_")
# dtypeC = might not be FP8 type, keep track
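
For the offline flow, the "tf32" entry added to dtype_dict above lets tune_gemm_in_file replay recorded TF32 GEMMs: the data type maps back to torch.float32 and torch.backends.cuda.matmul.allow_tf32 is toggled per recorded line before tuning. A rough sketch of that workflow, following test_tf32_offline_tunableop above (filenames and matrix sizes are illustrative; the device ordinal is appended to the untuned filename):

import os
import torch

os.environ["HIPBLASLT_ALLOW_TF32"] = "1"            # still required on ROCm (MI300+)
os.environ["PYTORCH_TUNABLEOP_UNTUNED_FILENAME"] = "tunableop_untuned.csv"

torch.cuda.tunable.enable()
torch.cuda.tunable.tuning_enable(False)             # record GEMMs only, no tuning yet
torch.cuda.tunable.record_untuned_enable(True)

A = torch.randn(41, 41, device="cuda")
B = torch.randn(41, 41, device="cuda")
torch.backends.cuda.matmul.allow_tf32 = True
torch.matmul(A, B)                                  # recorded as GemmTunableOp_tf32_NN
torch.backends.cuda.matmul.allow_tf32 = False
torch.matmul(A, B)                                  # recorded as GemmTunableOp_float_NN

# Later, tune everything that was recorded.
torch.cuda.tunable.record_untuned_enable(False)
torch.cuda.tunable.tuning_enable(True)
torch.cuda.tunable.set_max_tuning_iterations(1)
ordinal = torch.cuda.current_device()
torch.cuda.tunable.tune_gemm_in_file(f"tunableop_untuned{ordinal}.csv")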