Commit dab139b

enable to_mxfp8 cast for DTensor
Summary:

1. Add a test for casting a DTensor to mxfp8.
2. Make the test pass:
   a. Remove the addition of epsilon. It is not supported in the DTensor world,
      but we also no longer need it, since we are no longer using `log2` anywhere.
   b. Replace `<<` with `torch.bitwise_left_shift` and `>>` with
      `torch.bitwise_right_shift`. The short versions are silently broken for
      DTensor inputs, but the verbose versions work.
3. Set up the wiring for testing mxfp8 with TP on a toy model. Note that making
   this work is split out into the next PR, as this PR got too large.

Test Plan:

```bash
./test/prototype/mx_formats/test_dtensor.sh
./test/float8/test_dtensor.sh
pytest test/prototype/mx_formats/
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d82a83f
ghstack-comment-id: 2993264092
Pull Request resolved: #2420
1 parent 2af2241 commit dab139b
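As context for item 2b in the summary, the snippet below is a minimal single-process sketch (with illustrative tensor names, not code from this PR) of the same exponent/scale bit math written with the verbose ops. On plain tensors it matches `>>`/`<<` exactly, and per the commit the function forms also dispatch correctly for DTensor inputs (see pytorch/pytorch#156533).

```python
import torch

MBITS_F32, EXP_BIAS_F32 = 23, 127

# Extract floor(log2(max_abs)) by reading the biased fp32 exponent field,
# using torch.bitwise_right_shift instead of `>>`.
max_abs = torch.tensor([1.5, 3.0, 0.25])
max_abs_int32 = max_abs.view(torch.int32)
extracted_pow2 = (
    torch.bitwise_right_shift(max_abs_int32, MBITS_F32) & 0b11111111
) - EXP_BIAS_F32
assert extracted_pow2.tolist() == [0, 1, -2]

# Rebuild an fp32 power-of-two scale by shifting the biased exponent back into
# the exponent field, using torch.bitwise_left_shift instead of `<<`.
scale_e8m0_biased = extracted_pow2 + EXP_BIAS_F32
scale_fp32 = torch.bitwise_left_shift(scale_e8m0_biased, MBITS_F32).view(torch.float32)
assert scale_fp32.tolist() == [1.0, 2.0, 0.25]
```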

File tree

5 files changed: +169, -18 lines changed

test/prototype/mx_formats/test_dtensor.py (new)
test/prototype/mx_formats/test_dtensor.sh (new)
torchao/prototype/mx_formats/kernels.py
torchao/prototype/mx_formats/mx_tensor.py
torchao/testing/training/dtensor_utils.py

test/prototype/mx_formats/test_dtensor.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Test numerics of manually defined float16 TP vs mxfp8 TP of toy models

Note: for now, this does not run in CI.
TODO(future): make this run in CI
"""

import os

import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from tqdm import tqdm

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.testing.training.dtensor_utils import (
    _test_lowp_mlp_tensor_parallelism_base,
)

torch.set_float32_matmul_precision("high")


def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", -1))
    device_mesh = init_device_mesh("cuda", (world_size,))
    # seed must be the same in all processes
    torch.manual_seed(1)
    local_rank = torch.distributed.get_rank()
    torch.cuda.set_device(local_rank)
    return device_mesh


def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
    device = mesh.device_type

    x_fp32 = torch.rand(size, size, device=device)
    x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=size // 2)

    dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
    dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=size // 2)
    assert isinstance(dist_x_fp8, DTensor)

    # Verify that the result of to_mx with DTensor matches the slice of the
    # result of to_mx without DTensor. This will fail on numeric op mismatches.
    local_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    assert size % world_size == 0, "unsupported"
    x_fp8_fp32 = x_fp8.to_dtype(torch.float32)
    rows_per_slice = size // world_size
    slice_start = local_rank * rows_per_slice
    slice_end = (local_rank + 1) * rows_per_slice
    x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
    torch.testing.assert_close(
        x_fp8_fp32_slice, dist_x_fp8.to_local().to_dtype(torch.float32), atol=0, rtol=0
    )


def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
    # TODO(future PR): assert that the K dim must be divisible by block size,
    # today this is silently incorrect if block_size is greater than K
    config.block_size = 16
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size, compile=False, allgather_in_lowp=False
    )

    # TODO(future PR): compile


if __name__ == "__main__":
    device_mesh = setup_distributed()
    tests = [
        _test_dtensor_cast_to_mxfp8,
        # TODO(next PR): enable this (current PR got too large, so splitting)
        # _test_mxfp8_mlp_tensor_parallelism_eager,
    ]

    for test in tqdm(tests, desc="Running tests"):
        try:
            test(device_mesh)
        except Exception as e:
            print(f"Test {test.__name__} failed with error: {e}")
            raise e

    torch.distributed.destroy_process_group()

test/prototype/mx_formats/test_dtensor.sh

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
#!/bin/bash

# terminate script on first error
set -e

if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
    echo "Skipping test_dtensor.sh because no CUDA devices are available."
    exit
fi

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_dtensor.py
# NCCL_DEBUG=WARN torchrun --nproc_per_node 1 test/prototype/mx_formats/test_dtensor.py

torchao/prototype/mx_formats/kernels.py

Lines changed: 0 additions & 3 deletions
@@ -1102,15 +1102,12 @@ def _triton_calculate_scale(x, axis):
     bf16_mbits = 7
     bf16_exp_bias = 127
     fp32_mbits = 23
-    # We use a small epsilon to avoid division by zero
-    epsilon = 1e-10

     # Find the maximum absolute value for each row
     max_abs = tl.max(x, axis=axis)

     # Calculate the e8m0 scale by extracting the exponent (floor)
     # TODO(future PR): support other exponent extraction types (ceil, RNE)
-    max_abs = max_abs + epsilon
     max_abs = max_abs.to(tl.bfloat16)
     max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
     extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
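Why the epsilon could be removed: the old `log2`-based path returned `-inf` for all-zero rows, while the bit-level extraction above simply reads an exponent field of 0 and yields `0 - 127 = -127`. Below is a small plain-PyTorch re-expression of the same bfloat16 bit math as a sanity check (a sketch only, not the Triton kernel itself).

```python
import torch

BF16_MBITS, BF16_EXP_BIAS = 7, 127

max_abs = torch.tensor([0.0, 1.0, 4.0], dtype=torch.bfloat16)
max_abs_int16 = max_abs.view(torch.int16)
# Read the biased exponent field; an all-zero row gives 0 - 127 = -127
# instead of the -inf that log2(0) would have produced.
extracted_pow2 = (
    torch.bitwise_right_shift(max_abs_int16, BF16_MBITS) & 0b11111111
) - BF16_EXP_BIAS
assert extracted_pow2.tolist() == [-127, 0, 2]
```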

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 36 additions & 9 deletions
@@ -21,6 +21,7 @@
 from typing import Callable, Dict, Union

 import torch
+from torch.distributed._tensor import DTensor

 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
@@ -166,6 +167,8 @@ def to_mx(
     # calculate the scale in e8m0 format

     orig_shape = data_hp.shape
+    # TODO(future PR): fix this line for TP, currently this reshape does not work
+    # for rank 3 tensor where dim1 is sharded
     data_hp = data_hp.reshape(-1, block_size)

     # find max value of the data
@@ -174,10 +177,6 @@ def to_mx(
     # section 6.3.
     max_abs = torch.amax(torch.abs(data_hp), 1)

-    # Add an epsilon to prevent the log2 function call for returning -inf
-    # where the values are zero.
-    eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
-
     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable
     # in the element data type, and get the mbits at the same time
@@ -233,8 +232,12 @@ def to_mx(
     )

     # Calculate the scale for different modes
-    max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
-    extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
+    max_abs_int32 = max_abs.view(hp_int_dtype)
+    # For now, use `torch.bitwise_right_shift` instead of `>>` to support DTensor
+    # See https://github.com/pytorch/pytorch/issues/156533.
+    extracted_pow2 = (
+        (torch.bitwise_right_shift(max_abs_int32, hp_mbits)) & 0b11111111
+    ) - hp_exp_bias

     if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
         scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
@@ -266,9 +269,11 @@
     )

     # For now, calculate the scale in floating point.
-    scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(
-        torch.float32
-    )
+    # For now, use `torch.bitwise_left_shift` instead of `<<` to support DTensor
+    # See https://github.com/pytorch/pytorch/issues/156533.
+    scale_fp32 = (
+        torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)
+    ).view(torch.float32)

     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
@@ -597,6 +602,28 @@ def to_mx(
         scale_e8m0_biased, data_lp = to_mx(
             data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
         )
+        if isinstance(scale_e8m0_biased, DTensor):
+            assert isinstance(data_lp, DTensor), "unsupported"
+            local_scale_e8m0_biased = scale_e8m0_biased.to_local()
+            local_data_lp = data_lp.to_local()
+            inner_mx_tensor = MXTensor(
+                local_scale_e8m0_biased,
+                local_data_lp,
+                elem_dtype,
+                block_size,
+                data_hp.dtype,
+                use_fp4_custom_triton_dequant_kernel,
+                gemm_kernel_choice,
+                pack_fp6,
+            )
+            return DTensor.from_local(
+                inner_mx_tensor,
+                data_lp.device_mesh,
+                data_lp.placements,
+                run_check=False,
+                shape=data_lp.size(),
+                stride=data_lp.stride(),
+            )
         return MXTensor(
             scale_e8m0_biased,
             data_lp,
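The last hunk above unwraps the DTensor inputs, builds an `MXTensor` from the local shards, and re-wraps it with `DTensor.from_local`, preserving the global shape and stride. Below is a hypothetical single-rank sketch of that wrap/unwrap pattern using a plain tensor on a one-process CPU mesh (so it runs without `torchrun`); the doubling computation and variable names are illustrative, not torchao code.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Single-rank process group on CPU so the sketch runs without torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

x = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# "Compute" on the local shard, then re-wrap the result as a DTensor with the
# original mesh/placements and the global shape/stride, mirroring how the new
# branch in MXTensor.to_mx wraps the local MXTensor it constructs.
local_out = x.to_local() * 2.0
out = DTensor.from_local(
    local_out,
    x.device_mesh,
    x.placements,
    run_check=False,
    shape=x.size(),
    stride=x.stride(),
)
assert isinstance(out, DTensor) and out.size() == x.size()

dist.destroy_process_group()
```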

torchao/testing/training/dtensor_utils.py

Lines changed: 17 additions & 6 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+from typing import Union

 import torch
 import torch.nn as nn
@@ -24,6 +25,8 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.quantization import quantize_


 class FeedForward(nn.Module):
@@ -36,7 +39,9 @@ def __init__(self):
         self.out_proj = nn.Linear(32, 16, bias=False)

     def forward(self, x):
-        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
+        x = F.silu(self.w1(x)) * self.w2(x)
+        x = self.out_proj(x)
+        return x


 class ToyModel(nn.Module):
@@ -50,20 +55,26 @@ def forward(self, x):


 def _test_lowp_mlp_tensor_parallelism_base(
     mesh: DeviceMesh,
-    config: Float8LinearConfig,
+    config: Union[Float8LinearConfig, MXLinearConfig],
     size=16,
     compile: bool = False,
     allgather_in_lowp: bool = False,
 ):
     device = mesh.device_type

+    # TODO(future): remove this once float8 training works with `quantize_` API
+    convert_model_func = convert_to_float8_training
+    if isinstance(config, MXLinearConfig):
+        convert_model_func = quantize_
+
     toy_model = ToyModel().to(device)
-    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
+    toy_model_fp8 = copy.deepcopy(toy_model)
+    convert_model_func(toy_model_fp8, config=config)

     tp_model = copy.deepcopy(toy_model)
-    tp_model = convert_to_float8_training(tp_model, config=config)
+    convert_model_func(tp_model, config=config)
     sp_model = copy.deepcopy(toy_model)
-    sp_model = convert_to_float8_training(sp_model, config=config)
+    convert_model_func(sp_model, config=config)

     # For tensorwise scaling, enable float8 all_gather.
     # For rowwise scaling, keep high precision all_gather. Motivation for
@@ -108,7 +119,7 @@ def _test_lowp_mlp_tensor_parallelism_base(

     # prepare_input_cls with specific submodule fqn
     sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = convert_to_float8_training(sp_model2, config=config)
+    convert_model_func(sp_model2, config=config)

     if not allgather_in_lowp:
         prepare_input = prepare_input_cls(
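To make the `convert_model_func` dispatch above concrete, here is a hedged standalone sketch (the `convert_toy_model` helper and the toy module are illustrative, not torchao APIs): an `MXLinearConfig` is applied through the `quantize_` API, while a `Float8LinearConfig` still goes through `convert_to_float8_training`.

```python
from typing import Union

import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.quantization import quantize_


def convert_toy_model(
    model: nn.Module, config: Union[Float8LinearConfig, MXLinearConfig]
) -> nn.Module:
    # Mirrors the dispatch in _test_lowp_mlp_tensor_parallelism_base: the mx
    # path uses quantize_, the float8 path keeps the legacy converter
    # (the diff's TODO: collapse onto quantize_ once float8 training supports it).
    if isinstance(config, MXLinearConfig):
        quantize_(model, config=config)  # modifies the model in place
    else:
        convert_to_float8_training(model, config=config)
    return model


if __name__ == "__main__" and torch.cuda.is_available():
    model = nn.Sequential(nn.Linear(32, 32, bias=False)).cuda()
    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
    config.block_size = 16  # K must be divisible by block_size (see the TODO above)
    convert_toy_model(model, config)
```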
