
Commit 20ad852

petrex authored and jeffdaily committed
[ROCm] OCP FP8 Support for new GPUs (pytorch#146632)
TL;DR: Follow-up to, and built on top of, pytorch#144476. Adds OCP FP8 support for gfx950; see pytorch/ao#1677.

This pull request improves compatibility and support for new GPU architectures and data types, particularly for ROCm. The key updates add support for new ROCm versions and GPU architectures, update data type handling, and remove outdated checks.

* [`aten/src/ATen/Context.cpp`](diffhunk://#diff-33de472d304acbe57d693c8567370c638068bedc1aa0ce8e9dc115dad05a7810L323-R326): Added support for the new GPU architectures `gfx1200`, `gfx1201`, and `gfx950`, gated on ROCm version checks.
* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199): Updated the supported-architecture lists in multiple functions to include `gfx1200`, `gfx1201`, and `gfx950`, gated on ROCm version checks. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL865-R876)
* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L81-L98): Extended the data type conversion to cover the new float8 types in both CUDA and ROCm builds.
* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fL29-R80): Updated the `HipDataTypeFor` template to handle the new float8 types, with hard-coded enum values for ROCm versions prior to 6.3.
* [`cmake/public/LoadHIP.cmake`](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197): Removed the check for `HIP_NEW_TYPE_ENUMS`, which is no longer necessary with the updated ROCm versions. [[1]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197) [[2]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L211-R182)

These changes ensure better compatibility and performance on newer hardware and software environments, particularly for users leveraging ROCm and CUDA for deep learning and scientific computing.

Pull Request resolved: pytorch#146632
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <[email protected]>
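For context only, the sketch below shows roughly how the OCP FP8 path touched by this commit would be exercised from Python on a ROCm GPU that supports it (for example gfx942, or gfx950/gfx120x on new enough ROCm). It is a minimal, hedged illustration and not part of the commit: `torch._scaled_mm` is a private API whose signature has shifted across releases, and the shapes, scales, and dtypes here are arbitrary placeholders.

```python
# Minimal sketch (not part of this commit): run a scaled GEMM with OCP FP8
# inputs on a ROCm device that supports it. Shapes and scales are arbitrary.
import torch

if torch.cuda.is_available() and torch.version.hip:
    device = "cuda"
    x = torch.randn(32, 64, device=device).to(torch.float8_e4m3fn)   # OCP FP8 activations
    w = torch.randn(16, 64, device=device).to(torch.float8_e4m3fn)   # OCP FP8 weights
    scale_x = torch.tensor(1.0, device=device)                       # per-tensor scales
    scale_w = torch.tensor(1.0, device=device)
    # _scaled_mm expects the second operand in column-major layout, hence the transpose.
    out = torch._scaled_mm(x, w.t(), scale_a=scale_x, scale_b=scale_w,
                           out_dtype=torch.bfloat16)
    print(out.shape)  # torch.Size([32, 16])
```

Whether this runs depends on the device architecture and ROCm version checks added in `Blas.cpp` (`_scaled_mm_allowed_device`).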
1 parent bf084c3 commit 20ad852

File tree

12 files changed: +114 −53 lines changed


aten/src/ATen/Context.cpp

+5-2
@@ -318,9 +318,12 @@ at::BlasBackend Context::blasPreferredBackend() {
   if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
     static const bool hipblaslt_unsupported = []() {
       static const std::vector<std::string> archs = {
-          "gfx90a", "gfx940", "gfx941", "gfx942",
+          "gfx90a", "gfx942"
 #if ROCM_VERSION >= 60300
-          "gfx1100", "gfx1101"
+          , "gfx1100", "gfx1101", "gfx1200", "gfx1201"
+#endif
+#if ROCM_VERSION >= 60500
+          , "gfx950"
 #endif
       };
       for (auto index: c10::irange(getNumGPUs())) {

aten/src/ATen/cuda/CUDADataType.h

+1-8
@@ -78,24 +78,17 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
       return CUDA_R_64I;
     case c10::ScalarType::BFloat16:
       return CUDA_R_16BF;
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || (defined(USE_ROCM) && ROCM_VERSION >= 60300)
     case c10::ScalarType::Float8_e4m3fn:
       return CUDA_R_8F_E4M3;
     case c10::ScalarType::Float8_e5m2:
       return CUDA_R_8F_E5M2;
 #endif
 #if defined(USE_ROCM)
-#if defined(HIP_NEW_TYPE_ENUMS)
     case c10::ScalarType::Float8_e4m3fnuz:
       return HIP_R_8F_E4M3_FNUZ;
     case c10::ScalarType::Float8_e5m2fnuz:
       return HIP_R_8F_E5M2_FNUZ;
-#else
-    case c10::ScalarType::Float8_e4m3fnuz:
-      return static_cast<hipDataType>(1000);
-    case c10::ScalarType::Float8_e5m2fnuz:
-      return static_cast<hipDataType>(1001);
-#endif
 #endif
     default:
       TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

+34-7
@@ -26,38 +26,65 @@
 namespace at::cuda::tunable {
 
 template <typename T>
-constexpr hipblasDatatype_t HipDataTypeFor();
+constexpr hipDataType HipDataTypeFor();
 
 template <>
-constexpr hipblasDatatype_t HipDataTypeFor<float>() {
+constexpr hipDataType HipDataTypeFor<float>() {
   return HIP_R_32F;
 }
 
 template <>
-constexpr hipblasDatatype_t HipDataTypeFor<Half>() {
+constexpr hipDataType HipDataTypeFor<Half>() {
   return HIP_R_16F;
 }
 
 template <>
-constexpr hipblasDatatype_t HipDataTypeFor<BFloat16>() {
+constexpr hipDataType HipDataTypeFor<BFloat16>() {
   return HIP_R_16BF;
 }
 
 template <>
-constexpr hipblasDatatype_t HipDataTypeFor<double>() {
+constexpr hipDataType HipDataTypeFor<double>() {
   return HIP_R_64F;
 }
 
 template <>
-constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e4m3fnuz>() {
+constexpr hipDataType HipDataTypeFor<c10::Float8_e4m3fnuz>() {
   return HIP_R_8F_E4M3_FNUZ;
 }
 
 template <>
-constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e5m2fnuz>() {
+constexpr hipDataType HipDataTypeFor<c10::Float8_e5m2fnuz>() {
   return HIP_R_8F_E5M2_FNUZ;
 }
 
+// This code is instantiated regardless of ROCm version.
+// Prior to ROCm 6.3, we hard-code the known enum values.
+template <>
+constexpr hipDataType HipDataTypeFor<c10::Float8_e4m3fn>() {
+#if ROCM_VERSION >= 60300
+  return HIP_R_8F_E4M3;
+#else
+  return static_cast<hipDataType>(28);
+#endif
+}
+
+template <>
+constexpr hipDataType HipDataTypeFor<c10::Float8_e5m2>() {
+#if ROCM_VERSION >= 60300
+  return HIP_R_8F_E5M2;
+#else
+  return static_cast<hipDataType>(29);
+#endif
+}
+
+// This type is not intended for matrix types but rather a scale factor.
+// Return a dummy value to satisfy linker.
+template <>
+constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
+  return static_cast<hipDataType>(500);
+}
+
 template <typename T>
 int GetBatchFromParams(const GemmParams<T>* params) {
   return 1;

aten/src/ATen/cuda/tunable/TunableGemm.h

+6
@@ -21,6 +21,7 @@
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
 #include <c10/util/Float8_e5m2fnuz.h>
+#include <c10/util/Float8_e8m0fnu.h>
 #include <c10/util/StringUtil.h>
 #include <fmt/printf.h>
 
@@ -181,6 +182,11 @@ inline const char* TypeName(Float8_e5m2fnuz v) {
   return "Float8_e5m2fnuz";
 }
 
+template <>
+inline const char* TypeName(Float8_e8m0fnu v) {
+  return "Float8_e8m0fnu";
+}
+
 template <>
 inline const char* TypeName(c10::complex<double> v) {
   return "c10::complex<double>";

aten/src/ATen/native/cuda/Blas.cpp

+43-4
@@ -191,9 +191,12 @@ static bool isSupportedHipLtROCmArch(int index) {
   hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
   std::string device_arch = prop->gcnArchName;
   static const std::vector<std::string> archs = {
-      "gfx90a", "gfx940", "gfx941", "gfx942",
+      "gfx90a", "gfx942"
 #if ROCM_VERSION >= 60300
-      "gfx1100", "gfx1101"
+      , "gfx1100", "gfx1101", "gfx1200", "gfx1201"
+#endif
+#if ROCM_VERSION >= 60500
+      , "gfx950"
 #endif
   };
   for (std::string arch : archs) {
@@ -862,7 +865,15 @@ static bool _scaled_mm_allowed_device() {
   auto dprops = at::cuda::getCurrentDeviceProperties();
 #ifdef USE_ROCM
   std::string device_arch = dprops->gcnArchName;
-  static const std::vector<std::string> archs = {"gfx940", "gfx941", "gfx942"};
+  static const std::vector<std::string> archs = {
+      "gfx942"
+#if ROCM_VERSION >= 60300
+      ,"gfx1200", "gfx1201"
+#endif
+#if ROCM_VERSION >= 60500
+      ,"gfx950"
+#endif
+  };
   for (std::string arch : archs) {
     size_t substring = device_arch.find(arch);
     if (substring != std::string::npos) {
@@ -1144,6 +1155,34 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
           BLASOP_A, BLASOP_B> scaledgemm{}; \
       scaledgemm(&params); \
     } \
+  } \
+  else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
+    if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
+      static at::cuda::tunable::ScaledGemmTunableOp< \
+          at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
+          BLASOP_A, BLASOP_B> scaledgemm{}; \
+      scaledgemm(&params); \
+    } \
+    else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
+      static at::cuda::tunable::ScaledGemmTunableOp< \
+          at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
+          BLASOP_A, BLASOP_B> scaledgemm{}; \
+      scaledgemm(&params); \
+    } \
+  } \
+  else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
+    if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
+      static at::cuda::tunable::ScaledGemmTunableOp< \
+          at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
+          BLASOP_A, BLASOP_B> scaledgemm{}; \
+      scaledgemm(&params); \
+    } \
+    else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
+      static at::cuda::tunable::ScaledGemmTunableOp< \
+          at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
+          BLASOP_A, BLASOP_B> scaledgemm{}; \
+      scaledgemm(&params); \
+    } \
+  } \
   }
   AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
     bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
@@ -1186,7 +1225,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
       TORCH_CHECK(false, "unreachable");
     }
   }),
-  kHalf, kBFloat16, kFloat8_e4m3fnuz, kFloat8_e5m2fnuz, AT_EXPAND(AT_FLOATING_TYPES));
+  kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
 #undef TUNABLE_DISPATCH
 }
 else

cmake/Dependencies.cmake

+2-2
@@ -1026,8 +1026,8 @@ if(USE_ROCM)
   list(APPEND HIP_HIPCC_FLAGS --offload-compress)
   list(APPEND HIP_CXX_FLAGS -D_GLIBCXX_USE_CXX11_ABI=${GLIBCXX_USE_CXX11_ABI})
   list(APPEND HIP_CXX_FLAGS -DHIP_ENABLE_WARP_SYNC_BUILTINS)
-  if(HIP_NEW_TYPE_ENUMS)
-    list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
+  if(WIN32)
+    add_definitions(-DROCM_ON_WINDOWS)
   endif()
   add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
   add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})

cmake/public/LoadHIP.cmake

-27
@@ -175,34 +175,7 @@ if(HIP_FOUND)
   # roctx is part of roctracer
   find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
 
-  # check whether HIP declares new types
   set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
-  set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
-  file(WRITE ${file} ""
-    "#include <hip/library_types.h>\n"
-    "int main() {\n"
-    "    hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n"
-    "    return 0;\n"
-    "}\n"
-    )
-
-  try_compile(hip_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
-    CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
-    COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
-    OUTPUT_VARIABLE hip_compile_output)
-
-  if(hip_compile_result)
-    set(HIP_NEW_TYPE_ENUMS ON)
-    #message("HIP is using new type enums: ${hip_compile_output}")
-    message("HIP is using new type enums")
-  else()
-    set(HIP_NEW_TYPE_ENUMS OFF)
-    #message("HIP is NOT using new type enums: ${hip_compile_output}")
-    message("HIP is NOT using new type enums")
-  endif()
-else() # Win32
-  # With HIP-SDK 6.2, HIP declares new enum types on Windows
-  set(HIP_NEW_TYPE_ENUMS ON)
 endif()
 
 if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")

test/test_linalg.py

+7-1
@@ -48,7 +48,13 @@
 def blaslt_supported_device():
     if torch.cuda.is_available():
         if torch.version.hip:
-            for arch in ['gfx90a', 'gfx94']:
+            ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
+            archs = ['gfx90a', 'gfx94']
+            if ROCM_VERSION >= (6, 3):
+                archs.extend(['gfx110', 'gfx120'])
+            if ROCM_VERSION >= (6, 5):
+                archs.append('gfx95')
+            for arch in archs:
                 if arch in torch.cuda.get_device_properties(0).gcnArchName:
                     return True
         else:

test/test_matmul_cuda.py

+1-1
@@ -214,7 +214,7 @@ def _expand_to_batch(t: torch.Tensor):
 
 f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
 
-if torch.version.hip:
+if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
     e4m3_type = torch.float8_e4m3fnuz
     e5m2_type = torch.float8_e5m2fnuz
     E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max

torch/_utils_internal.py

+4
@@ -225,12 +225,16 @@ def max_clock_rate():
        return 1700
    elif "gfx908" in gcn_arch:
        return 1502
+   elif "gfx12" in gcn_arch:
+       return 1700
    elif "gfx11" in gcn_arch:
        return 1700
    elif "gfx103" in gcn_arch:
        return 1967
    elif "gfx101" in gcn_arch:
        return 1144
+   elif "gfx95" in gcn_arch:
+       return 1700  # TODO: placeholder, get actual value
    else:
        return 1100
 
torch/testing/_internal/common_cuda.py

+9-1
@@ -82,7 +82,15 @@ def evaluate_platform_supports_cudnn_attention():
 def evaluate_platform_supports_fp8():
     if torch.cuda.is_available():
         if torch.version.hip:
-            return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
+            ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
+            archs = ['gfx94']
+            if ROCM_VERSION >= (6, 3):
+                archs.extend(['gfx120'])
+            if ROCM_VERSION >= (6, 5):
+                archs.append('gfx95')
+            for arch in archs:
+                if arch in torch.cuda.get_device_properties(0).gcnArchName:
+                    return True
         else:
             return SM90OrLater or torch.cuda.get_device_capability() == (8, 9)
     return False

torch/utils/hipify/cuda_to_hip_mappings.py

+2
@@ -3863,6 +3863,8 @@
     ("CUDA_C_64I", ("HIP_C_64I", CONV_TYPE, API_RUNTIME)),
     ("CUDA_R_64U", ("HIP_R_64U", CONV_TYPE, API_RUNTIME)),
     ("CUDA_C_64U", ("HIP_C_64U", CONV_TYPE, API_RUNTIME)),
+    ("CUDA_R_8F_E4M3", ("HIP_R_8F_E4M3", CONV_TYPE, API_RUNTIME)),
+    ("CUDA_R_8F_E5M2", ("HIP_R_8F_E5M2", CONV_TYPE, API_RUNTIME)),
     (
         "MAJOR_VERSION",
         ("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED),
