
Commit 5dc9128

drisspg authored and pytorchmergebot committed
FP8 rowwise scaling (pytorch#125204)
# Summary

This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:

- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.

It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, i.e. the "row".

The following two PRs were required to enable local builds:

- [PR pytorch#126185](pytorch#126185)
- [PR pytorch#125523](pytorch#125523)

### Todo

We still do not build our Python wheels with this architecture. @ptrblck @malfet, should we replace `sm_90` with `sm_90a`?

The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit: https://github.com/pytorch/pytorch/pull/125204/files#r1586986954

#### ifdef

I tried to use `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel, but I was having a hell of a time with it, so I am not really sure of the right way to do this.

Kernel Credit: @jwfromm

Pull Request resolved: pytorch#125204
Approved by: https://github.com/lw, https://github.com/malfet
1 parent 4f9fcd7 commit 5dc9128
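For context, here is a minimal usage sketch of the rowwise path described above. This is illustrative only: it assumes an sm_90a (H100) build with this kernel compiled in, it goes through the private `torch._scaled_mm` entry point, the shapes are made up, and keyword names and the `(out, amax)` return tuple reflect the op's signature at the time of this commit and may differ in other versions.

```python
import torch

# Hypothetical sizes; mat1 is [M, K] row-major, mat2 is [K, N] column-major (TN layout).
M, K, N = 64, 128, 256
x = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

# Rowwise scales: fp32, 1-dimensional, length M for x and length N for w, both contiguous.
scale_x = torch.rand(M, device="cuda", dtype=torch.float32)
scale_w = torch.rand(N, device="cuda", dtype=torch.float32)

# w.t() is a column-major [K, N] view; bf16 output is required for the rowwise kernel.
out, amax = torch._scaled_mm(
    x,
    w.t(),
    scale_a=scale_x,
    scale_b=scale_w,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)
```

If the scales are instead single-element tensors, the same call falls through to the existing per-tensor cuBLASLt path.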

File tree

8 files changed: +855, -25 lines


aten/src/ATen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -473,6 +473,7 @@ endif()
 
 if(USE_CUDA AND NOT USE_ROCM)
   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
   if($ENV{ATEN_STATIC_CUDA})
     list(APPEND ATen_CUDA_DEPENDENCY_LIBS
       ${CUDA_LIBRARIES}

aten/src/ATen/cuda/detail/LazyNVRTC.cpp

Lines changed: 37 additions & 0 deletions
@@ -170,6 +170,43 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *);
 CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int);
 CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction);
 
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
+CUresult CUDAAPI
+cuTensorMapEncodeTiled(
+    CUtensorMap* tensorMap,
+    CUtensorMapDataType tensorDataType,
+    cuuint32_t tensorRank,
+    void* globalAddress,
+    const cuuint64_t* globalDim,
+    const cuuint64_t* globalStrides,
+    const cuuint32_t* boxDim,
+    const cuuint32_t* elementStrides,
+    CUtensorMapInterleave interleave,
+    CUtensorMapSwizzle swizzle,
+    CUtensorMapL2promotion l2Promotion,
+    CUtensorMapFloatOOBfill oobFill) {
+  auto fn = reinterpret_cast<decltype(&cuTensorMapEncodeTiled)>(
+      getCUDALibrary().sym(__func__));
+  if (!fn)
+    throw std::runtime_error("Can't get cuTensorMapEncodeTiled");
+  lazyNVRTC.cuTensorMapEncodeTiled = fn;
+  return fn(
+      tensorMap,
+      tensorDataType,
+      tensorRank,
+      globalAddress,
+      globalDim,
+      globalStrides,
+      boxDim,
+      elementStrides,
+      interleave,
+      swizzle,
+      l2Promotion,
+      oobFill);
+}
+
+#endif
+
 // Irregularly shaped functions
 CUresult CUDAAPI cuLaunchKernel(CUfunction f,
                                 unsigned int gridDimX,

aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h

Lines changed: 12 additions & 3 deletions
@@ -59,16 +59,25 @@ namespace at { namespace cuda {
   _(cuLinkAddData) \
   _(cuLinkComplete) \
   _(cuFuncSetAttribute) \
-  _(cuFuncGetAttribute)
+  _(cuFuncGetAttribute) \
+
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
+#define AT_FORALL_NVRTC_EXTENDED(_) \
+  AT_FORALL_NVRTC_BASE(_) \
+  _(cuTensorMapEncodeTiled)
+#else
+#define AT_FORALL_NVRTC_EXTENDED(_) \
+  AT_FORALL_NVRTC_BASE(_)
+#endif
 
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
 #define AT_FORALL_NVRTC(_) \
-  AT_FORALL_NVRTC_BASE(_) \
+  AT_FORALL_NVRTC_EXTENDED(_) \
   _(nvrtcGetCUBINSize) \
   _(nvrtcGetCUBIN)
 #else
 #define AT_FORALL_NVRTC(_) \
-  AT_FORALL_NVRTC_BASE(_)
+  AT_FORALL_NVRTC_EXTENDED(_)
 #endif
 
 #else

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 103 additions & 10 deletions
@@ -1,3 +1,7 @@
+#include <cstdint>
+#include <c10/util/Exception.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/ScalarType.h>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/core/Tensor.h>
 #include <ATen/core/NamedTensor.h>
@@ -10,6 +14,7 @@
 #include <ATen/cuda/tunable/TunableGemm.h>
 #include <ATen/native/Resize.h>
 #include <c10/util/MaybeOwned.h>
+#include <ATen/native/cuda/RowwiseScaledMM.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -819,24 +824,97 @@ static bool _scaled_mm_allowed_device() {
 #endif
 }
 
+namespace {
+
+enum class ScalingType {
+  TensorWise,
+  RowWise,
+  Error
+};
+
+// Validates the scale tensors passed to scaled_mm
+// and returns the type of scaling, i.e. which kernel to use
+ScalingType get_scaling_type(
+    const c10::optional<at::Tensor>& scale_a,
+    const c10::optional<at::Tensor>& scale_b,
+    int64_t dim_m,
+    int64_t dim_n) {
+  TORCH_CHECK(
+      scale_a.has_value() == scale_b.has_value(),
+      "Both scale_a and scale_b must be present or absent.");
+
+  if (scale_a.has_value()) {
+    // Both per-tensor and row-wise scaling expect fp32 tensors
+    TORCH_CHECK(
+        scale_a->scalar_type() == kFloat && scale_b->scalar_type() == kFloat,
+        "Both scale_a and scale_b must be float (fp32) tensors.");
+
+    // Check the singular scale case for per-tensor scaling
+    if (scale_a->numel() == 1 && scale_b->numel() == 1) {
+      return ScalingType::TensorWise;
+    } else if (scale_a->dim() == 1 && scale_a->size(0) == dim_m) {
+      // Check the per-row scaling case
+#if !defined(USE_ROCM) && !defined(_MSC_VER) || \
+    (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+      TORCH_CHECK(
+          scale_a->dim() == 1 && scale_b->dim() == 1,
+          "Both scale_a and scale_b must be 1-dimensional tensors");
+      TORCH_CHECK(
+          scale_b->size(0) == dim_n,
+          "For row-wise scaling, scale_b must have size ",
+          dim_n,
+          " but got ",
+          scale_b->size(0),
+          ".");
+      TORCH_CHECK(
+          scale_a->is_contiguous() && scale_b->is_contiguous(),
+          "Both scale_a and scale_b must be contiguous.");
+      return ScalingType::RowWise;
+#else
+      TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
+      return ScalingType::Error;
+#endif // !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) &&
+       // ROCM_VERSION >= 60000)
+    } else {
+      TORCH_CHECK(
+          false,
+          "For row-wise scaling, scale_a must be size ",
+          dim_m,
+          " but got ",
+          scale_a->numel(),
+          " and scale_b must be size ",
+          dim_n,
+          " but got ",
+          scale_b->numel(),
+          ".");
+      // Unreachable
+      return ScalingType::RowWise;
+    }
+  }
+  return ScalingType::Error;
+}
+
+} // namespace
+
 // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
 // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
 // If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed.
 // Known limitations:
 //  - Only works if mat1 is row-major and mat2 is column-major
 //  - Only works if matrices sizes are divisible by 32
-//
+//  - If 1-dimensional tensors are used then scale_a should be size = mat1.size(0)
+//    and scale_b should have size equal to mat2.size(1)
 // Arguments:
 //  - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
 //  - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
 //  - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
 //  - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
-//  - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
-//  - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
-//  - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type
+//  - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
+//  - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
+//  - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
 //  - `use_fast_accum`: if true, enables fast float8 accumulation
 //  - `out`: a reference to the output tensor
-//  - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace
+//  - `amax`: a reference to the amax tensor of the output, only mutated if the output is a float8 type and will be updated inplace
 
 std::tuple<Tensor&, Tensor&>
 _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
@@ -855,10 +933,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(
       mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
       mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-  TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat),
-       "scale_a must be float scalar");
-  TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat),
-       "scale_b must be a float scalar");
+
+  // Check what type of scaling we are doing based on inputs
+  ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
+  TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
+
   TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
        "scale_result must be a float scalar");
   TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
@@ -901,12 +980,26 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
         {scale_result_, "scale_result", 7}};
     checkAllSameGPU(__func__, targs);
   }
-
+  // Validation checks have passed, so let's resize the output to the actual size
   IntArrayRef mat1_sizes = mat1.sizes();
   IntArrayRef mat2_sizes = mat2.sizes();
   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
   at::native::resize_output(amax, {});
 
+  // We are doing row-wise scaling
+  if (scaling_choice == ScalingType::RowWise) {
+    TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
+    at::cuda::detail::f8f8bf16_rowwise(
+        mat1,
+        mat2,
+        scale_a.value(),
+        scale_b.value(),
+        bias,
+        use_fast_accum,
+        out);
+    return {out, amax};
+  }
+
   cublasCommonArgs args(mat1, mat2, out);
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
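To summarize the dispatch that `get_scaling_type` implements in the diff above, here is a rough Python restatement, illustrative only: the platform `#ifdef` gating and exact error messages are omitted, and the enum names are chosen for readability rather than taken from the PR.

```python
from enum import Enum

import torch


class ScalingType(Enum):
    TENSOR_WISE = "tensorwise"
    ROW_WISE = "rowwise"


def get_scaling_type(scale_a: torch.Tensor, scale_b: torch.Tensor,
                     dim_m: int, dim_n: int) -> ScalingType:
    # Both per-tensor and row-wise scaling expect fp32 scale tensors.
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32

    # A single scalar per operand selects the existing per-tensor cuBLASLt path.
    if scale_a.numel() == 1 and scale_b.numel() == 1:
        return ScalingType.TENSOR_WISE

    # Otherwise scale_a must be a contiguous 1-D tensor of length M and
    # scale_b a contiguous 1-D tensor of length N, selecting the rowwise kernel.
    if (scale_a.dim() == 1 and scale_a.size(0) == dim_m
            and scale_b.dim() == 1 and scale_b.size(0) == dim_n
            and scale_a.is_contiguous() and scale_b.is_contiguous()):
        return ScalingType.ROW_WISE

    raise ValueError("Invalid scale shapes for scaled_mm")
```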
