Commit bf084c3

vkuzo authored and pruthvistony committed
add the torch.float8_e8m0fnu dtype to PyTorch (pytorch#147466)
Summary:

Continuing the work from pytorch#146427, this adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in pytorch#146414. Please see that issue for a detailed definition of the format.

Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```

Done in this PR:
* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:
* performance optimizations for casting
* `torch._scaled_mm`
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Pull Request resolved: pytorch#147466
Approved by: https://github.com/drisspg
1 parent ed8c660 commit bf084c3
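
Since e8m0fnu carries only an exponent (8 exponent bits, no mantissa, no sign), every representable value is a power of two, so the float32 → float8_e8m0fnu → float32 round trip in the example above snaps each input to a power of two via RNE rounding of the exponent. Below is a minimal sketch of that expected behavior, assuming a build that includes this commit; the bit-level check is illustrative and not part of this PR's test plan.

```python
import torch

# round-trip a few positive float32 values through the exponent-only format
x = torch.tensor([0.1, 0.75, 1.0, 3.0, 100.0], dtype=torch.float32)
x_rt = x.to(torch.float8_e8m0fnu).to(torch.float32)
print(x_rt)

# every surviving value should be a power of two: for a normal float32 that
# means all 23 mantissa bits are zero after the round trip
mantissa_bits = x_rt.view(torch.int32) & 0x007FFFFF
print(bool((mantissa_bits == 0).all()))  # expected: True
```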

25 files changed: +535, -44 lines

aten/src/ATen/DLConvertor.cpp (+2)

@@ -63,10 +63,12 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::BFloat16:
       dtype.code = DLDataTypeCode::kDLBfloat;
       break;
+    // TODO(#146647): use macro here instead of spelling out each shell dtype
     case ScalarType::Float8_e5m2:
     case ScalarType::Float8_e5m2fnuz:
     case ScalarType::Float8_e4m3fn:
     case ScalarType::Float8_e4m3fnuz:
+    case ScalarType::Float8_e8m0fnu:
       TORCH_CHECK(false, "float8 types are not supported by dlpack");
       break;
     case ScalarType::QInt8:

aten/src/ATen/Dispatch_v2.h (+1, -1)

@@ -87,7 +87,7 @@
 
 #define AT_FLOAT8_TYPES \
   c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
-      c10::kFloat8_e4m3fnuz
+      c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
 
 #define AT_INTEGRAL_TYPES \
   c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort

aten/src/ATen/native/Copy.cpp (+2, -2)

@@ -59,8 +59,8 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2( \
-      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
-      kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, \
+      AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \

aten/src/ATen/native/TensorCompare.cpp (+2, -1)

@@ -403,7 +403,8 @@ Tensor isinf(const Tensor &self) {
 
 Tensor isfinite(const Tensor& self) {
   // Note: Integral tensor values are always finite
-  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
+      self.scalar_type() == kFloat8_e8m0fnu) {
     return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
   }
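
As a consequence of the isfinite() change above, e8m0fnu tensors report every element as finite, the same way integral tensors do. A hedged sketch of the user-visible behavior, assuming a build that includes this commit:

```python
import torch

# isfinite() short-circuits to an all-True bool tensor for float8_e8m0fnu,
# mirroring the integral-dtype fast path in the C++ change above
t = torch.rand(3, 3).to(torch.float8_e8m0fnu)
mask = torch.isfinite(t)
print(mask.dtype)        # torch.bool
print(bool(mask.all()))  # expected: True
```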

aten/src/ATen/native/cpu/CopyKernel.cpp (+4, -4)

@@ -204,12 +204,12 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
       kComplexHalf, kHalf, kBool, \
-      kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
-      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
-      kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
-      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \

aten/src/ATen/native/cpu/FillKernel.cpp (+3)

@@ -51,6 +51,9 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
     fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
   } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
     fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e8m0fnu) {
+    // TODO(#146647): use macro here instead of spelling out each float8 dtype
+    fill_non_native_type<at::Float8_e8m0fnu>(iter, value_scalar);
   } else {
     AT_DISPATCH_V2(
         iter.dtype(), "fill_cpu", AT_WRAP([&]() {

aten/src/ATen/native/cpu/IndexKernel.cpp (+7, -1)

@@ -184,7 +184,13 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
         }
       }),
       AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-      AT_EXPAND(AT_FLOAT8_TYPES),
+      // AT_EXPAND(AT_FLOAT8_TYPES),
+      // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+      // should not be supported here, then reenable AT_FLOAT8_DTYPES
+      kFloat8_e4m3fn,
+      kFloat8_e5m2,
+      kFloat8_e4m3fnuz,
+      kFloat8_e5m2fnuz,
       kComplexHalf,
       kHalf,
       kBool,

aten/src/ATen/native/cuda/Copy.cu (+23, -1)

@@ -144,6 +144,28 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
         gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
         break;
     }
+  } else if (dtype == kFloat8_e8m0fnu) {
+    // TODO(#146647): clean this up, too much copy-pasta
+    switch (other_dtype) {
+      case kFloat:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      case kHalf:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      case kBFloat16:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
+          return Float8_e8m0fnu(value);
+        });
+        break;
+      default:
+        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e8m0fnu x) { return x; });
+        break;
+    }
   } else {
     TORCH_CHECK(false, "This supposed ot be called only for Float8 types");
   }

@@ -157,7 +179,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
-  } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
+  } else if (isFloat8Type(dtype)) {
     float8_copy_kernel_cuda(iter);
   } else if (iter.dtype(1) == kFloat && (dtype == kBFloat16 || dtype == kHalf)) {
     if (dtype == kBFloat16) {
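
The branch added above gives float32, float16, and bfloat16 sources a dedicated no-cast path into Float8_e8m0fnu, with a generic fallback for everything else. A sketch of the casts this enables, assuming a CUDA device and a build that includes this commit:

```python
import torch

# exercise the new CUDA cast paths into float8_e8m0fnu
if torch.cuda.is_available():
    for src_dtype in (torch.float32, torch.float16, torch.bfloat16):
        x = torch.rand(8, device="cuda", dtype=src_dtype)
        y = x.to(torch.float8_e8m0fnu)  # handled by the gpu_kernel_nocast branches above
        print(src_dtype, y.to(torch.float32))  # cast back to float32 for readable output
```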

aten/src/ATen/native/cuda/Indexing.cu (+35, -5)

@@ -712,7 +712,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,

@@ -738,7 +744,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,

@@ -762,7 +774,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }),
    AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,

@@ -784,7 +802,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }),
    AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,

@@ -809,7 +833,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }),
    AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    AT_EXPAND(AT_FLOAT8_TYPES),
+    // AT_EXPAND(AT_FLOAT8_TYPES),
+    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+    // should not be supported here, then reenable AT_FLOAT8_DTYPES
+    kFloat8_e4m3fn,
+    kFloat8_e5m2,
+    kFloat8_e4m3fnuz,
+    kFloat8_e5m2fnuz,
     kComplexHalf,
     kHalf,
     kBool,

aten/src/ATen/native/cuda/jit_utils.h (+4)

@@ -225,6 +225,10 @@ template <> inline std::string typeName<at::Float8_e5m2fnuz>() {
 template <> inline std::string typeName<at::Float8_e4m3fnuz>() {
   return "at::Float8_e4m3fnuz";
 }
+template <> inline std::string typeName<at::Float8_e8m0fnu>() {
+  // TODO(#146647): Can the code here be made generic for any scalartype?
+  return "at::Float8_e8m0fnu";
+}
 
 #define TYPE_NAME_CASE(ctype, scalartype) \
   case ScalarType::scalartype: return typeName<ctype>();

c10/core/Scalar.h (+2, -9)

@@ -49,16 +49,9 @@ class C10_API Scalar {
 #define DEFINE_IMPLICIT_CTOR(type, name) \
   Scalar(type vv) : Scalar(vv, true) {}
 
-  AT_FORALL_SCALAR_TYPES_AND7(
-      Half,
-      BFloat16,
-      Float8_e5m2,
-      Float8_e4m3fn,
-      Float8_e5m2fnuz,
-      Float8_e4m3fnuz,
-      ComplexHalf,
-      DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
   AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR)
 
   // Helper constructors to allow Scalar creation from long and long long types
   // As std::is_same_v<long, long long> is false(except Android), one needs to

c10/core/ScalarType.cpp (+3)

@@ -222,6 +222,9 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
       return std::make_pair("float8_e5m2fnuz", "");
     case c10::ScalarType::Float8_e4m3fnuz:
       return std::make_pair("float8_e4m3fnuz", "");
+    case c10::ScalarType::Float8_e8m0fnu:
+      // TODO(#146647): macroify all of this
+      return std::make_pair("float8_e8m0fnu", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }

c10/core/ScalarType.h (+19, -3)

@@ -7,6 +7,7 @@
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
 #include <c10/util/Float8_e5m2fnuz.h>
+#include <c10/util/Float8_e8m0fnu.h>
 #include <c10/util/Half.h>
 #include <c10/util/bits.h>
 #include <c10/util/complex.h>

@@ -102,7 +103,8 @@ struct dummy_int1_7_t {};
   _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
   _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
   _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
-  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */
+  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
+  _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */
 
 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()

@@ -146,7 +148,8 @@ struct dummy_int1_7_t {};
   _(at::Float8_e5m2, Float8_e5m2) \
   _(at::Float8_e4m3fn, Float8_e4m3fn) \
   _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
-  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz)
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
+  _(at::Float8_e8m0fnu, Float8_e8m0fnu)
 
 enum class ScalarType : int8_t {
 #define DEFINE_ST_ENUM_VAL_(_1, n) n,

@@ -317,6 +320,13 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
   _(c10::quint4x2, QUInt4x2) \
   _(c10::quint2x4, QUInt2x4)
 
+#define AT_FORALL_FLOAT8_TYPES(_) \
+  _(at::Float8_e5m2, Float8_e5m2) \
+  _(at::Float8_e4m3fn, Float8_e4m3fn) \
+  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
+  _(at::Float8_e8m0fnu, Float8_e8m0fnu)
+
 #define AT_FORALL_COMPLEX_TYPES(_) \
   _(c10::complex<float>, ComplexFloat) \
   _(c10::complex<double>, ComplexDouble)

@@ -372,7 +382,8 @@ inline bool isIntegralType(ScalarType t) {
 
 inline bool isFloat8Type(ScalarType t) {
   return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz ||
-      t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz;
+      t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz ||
+      t == ScalarType::Float8_e8m0fnu;
 }
 
 inline bool isReducedFloatingType(ScalarType t) {

@@ -446,6 +457,10 @@ inline bool isSignedType(ScalarType t) {
   return std::numeric_limits< \
       ::c10::impl::ScalarTypeToCPPTypeT<ScalarType::name>>::is_signed;
 
+  // TODO(#146647): If we expect to have numeric_limits for everything,
+  // let's just have a big macro for the whole thing.
+  // If we're hardcoding it, let's just use the macro and a "true"/"false"
+  // below?
   switch (t) {
     case ScalarType::QInt8:
     case ScalarType::QUInt8:

@@ -467,6 +482,7 @@ inline bool isSignedType(ScalarType t) {
     CASE_ISSIGNED(Float8_e5m2fnuz);
     CASE_ISSIGNED(Float8_e4m3fn);
     CASE_ISSIGNED(Float8_e4m3fnuz);
+    CASE_ISSIGNED(Float8_e8m0fnu);
     CASE_ISSIGNED(Byte);
     CASE_ISSIGNED(Char);
     CASE_ISSIGNED(Short);
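
Taken together, the ScalarType.h changes register the new enum value, fold it into the float8 helper macros (AT_FORALL_FLOAT8_TYPES, isFloat8Type), and route isSignedType through numeric_limits. At the Python level this surfaces as ordinary dtype introspection; the sketch below reflects my expectation for an unsigned, exponent-only format and is not asserted by this diff:

```python
import torch

dt = torch.float8_e8m0fnu
print(dt)                    # torch.float8_e8m0fnu
print(dt.itemsize)           # 1 (one byte per element)
print(dt.is_floating_point)  # expected: True
print(dt.is_signed)          # expected: False -- the "u" marks the format as unsigned
```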
