@@ -26,6 +26,5 @@
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
-    BF16_EXP_BIAS,
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
@@ -63,6 +62,5 @@
 # TODO(later): read from somewhere else?
 SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
-EBITS_BF16, MBITS_BF16 = 8, 7
 EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
 EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
@@ -137,9 +135,7 @@ def _to_mx_rceil(
     )

     # scale and saturated cast the data elements to max of target dtype
-    data_lp = torch.clamp(
-        data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
-    )
+    data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
     return exponent, data_lp

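Note on the hunk above, not part of the diff: once to_mx hands _to_mx_rceil data laid out as (..., n_blocks, block_size), the per-block descale already carries a trailing singleton dimension, so plain broadcasting replaces the old .unsqueeze(1). A minimal sketch, with purely illustrative shapes and values:

import torch

# Illustrative only: blocked data and a per-block descale with a trailing dim of 1.
data_hp = torch.randn(4, 8, 32, dtype=torch.float32)
descale_fp = torch.rand(4, 8, 1, dtype=torch.float32) + 0.5
max_pos = 448.0  # e.g. the max representable value of float8_e4m3fn

# Broadcasting handles the block dimension; no explicit unsqueeze is needed.
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
print(data_lp.shape)  # torch.Size([4, 8, 32])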
@@ -160,22 +156,33 @@ def to_mx(
         torch.float,
     ), f"{data_hp.dtype} is not supported yet"
     # TODO(future PR): consider supporting padding
-    assert data_hp.numel() % block_size == 0, "unsupported"
+    assert data_hp.shape[-1] % block_size == 0, (
+        f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
+    )
     assert data_hp.is_contiguous(), "unsupported"
     assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

-    # calculate the scale in e8m0 format
-
     orig_shape = data_hp.shape
-    # TODO(future PR): fix this line for TP, currently this reshape does not work
-    # for rank 3 tensor where dim1 is sharded
-    data_hp = data_hp.reshape(-1, block_size)
+    data_hp = data_hp.reshape(
+        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
+    )

     # find max value of the data
     # Note: this only implements the `minimally supported` version of
     # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     # section 6.3.
-    max_abs = torch.amax(torch.abs(data_hp), 1)
+    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
+
+    # We cast to float32 here because
+    # in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below,
+    # if tensor parallel is enabled then the resulting shape is 2x larger
+    # than it should be under some conditions, likely because of a bug in
+    # the `view` op with DTensor and target dtype int16. I reproduce in
+    # torchtitan but not in a unit test, so not enough info to file a good
+    # issue in pytorch/pytorch. For now, work around. In the future we should
+    # debug and fix this properly.
+    data_hp = data_hp.to(torch.float32)
+    max_abs = max_abs.to(torch.float32)

     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable
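For context, not part of the diff: a rough sketch of the new blocked reshape and per-block amax, using made-up shapes that mimic a rank-3 input whose leading dims may be sharded under tensor parallel:

import torch

block_size = 32
data_hp = torch.randn(2, 8, 128, dtype=torch.float32)  # hypothetical input

orig_shape = data_hp.shape
# Reshape to (..., n_blocks, block_size); the leading dims are left untouched,
# which is what lets rank-3 / sharded tensors pass through.
data_hp = data_hp.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)

# Per-block max magnitude, with a trailing singleton dim so it broadcasts
# against data_hp in the scaling steps that follow.
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)

print(data_hp.shape)  # torch.Size([2, 8, 4, 32])
print(max_abs.shape)  # torch.Size([2, 8, 4, 1])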
@@ -206,17 +213,11 @@ def to_mx(
     if scaling_mode == ScaleCalculationMode.RCEIL:
         scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
     else:
-        if data_hp.dtype is torch.float32:
-            hp_int_dtype = torch.int32
-            hp_mbits = MBITS_F32
-            hp_ebits = EBITS_F32
-            hp_exp_bias = F32_EXP_BIAS
-        else:
-            assert data_hp.dtype is torch.bfloat16
-            hp_int_dtype = torch.int16
-            hp_mbits = MBITS_BF16
-            hp_ebits = EBITS_BF16
-            hp_exp_bias = BF16_EXP_BIAS
+        assert data_hp.dtype is torch.float32
+        hp_int_dtype = torch.int32
+        hp_mbits = MBITS_F32
+        hp_ebits = EBITS_F32
+        hp_exp_bias = F32_EXP_BIAS

         # rounding before calculating the largest power of 2
         # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
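Aside, not part of the diff: the bfloat16 branch can go away here because the hunk above casts data_hp to float32 unconditionally before this point, so only the float32 bit-layout constants are ever reached. A trivial sketch of the invariant, with the standard float32 values written out:

import torch

data_hp = torch.randn(16, dtype=torch.bfloat16)  # illustrative input
data_hp = data_hp.to(torch.float32)  # the workaround cast earlier in to_mx

# By the time the non-RCEIL scale calculation runs, this always holds,
# so the int16/bfloat16 bit-manipulation path is unreachable.
assert data_hp.dtype is torch.float32
hp_int_dtype, hp_ebits, hp_mbits, hp_exp_bias = torch.int32, 8, 23, 127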
@@ -285,7 +286,7 @@ def to_mx(
     scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

     # scale and saturated cast the data elements to max of target dtype
-    data_lp = data_hp / scale_fp32.unsqueeze(1)
+    data_lp = data_hp / scale_fp32

     if (
         elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
@@ -511,7 +512,6 @@ def __new__(
         assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, (
             f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
         )
-        assert len(scale_e8m0_bits.shape) == 1, "unsupported"
         assert data_bits.dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,