
Commit e7d334f

enable tensor parallelism for MXLinear
Summary:

Enables TP for MXLinear. Specifically:

1. change the reshape logic from `x.reshape(-1, block_size)` to `x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)`
2. modify the rest of the code to adhere to (1)
3. cast input tensor and max_abs to float32 before calculating the MX scale, in order to get around another bug in DTensor + view + int16 target type

(1) is necessary because the old reshape logic would flatten dims, which did not work if one of those flattened dims was sharded.

Test Plan:

```
pytest test/prototype/mx_formats
./test/prototype/mx_formats/test_dtensor.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 860833d
ghstack-comment-id: 3000664086
Pull Request resolved: #2434
1 parent d842fd4 commit e7d334f
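
To make change (1) concrete, here is a minimal standalone sketch (not the torchao code; `blockwise_amax`, the shapes, and the block size are illustrative assumptions) of how splitting only the last dimension into blocks leaves the leading dims, and therefore any sharding placed on them, untouched:

```python
import torch

def blockwise_amax(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # New-style reshape: keep all leading dims and split only the last dim
    # into (num_blocks, block_size), so a sharded leading dim is never folded
    # into the block dimension.
    orig_shape = x.shape
    x_blocks = x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
    # Per-block max-abs, keeping a trailing singleton dim for broadcasting.
    return torch.amax(torch.abs(x_blocks), dim=-1).unsqueeze(-1)

x = torch.randn(2, 8, 64)                      # e.g. (batch, seq, K)
print(blockwise_amax(x, block_size=32).shape)  # torch.Size([2, 8, 2, 1])

# Old-style x.reshape(-1, block_size) would flatten (batch, seq) together with
# the block dim, which fails when one of those flattened dims is sharded.
```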

File tree: 6 files changed, +45 -42 lines changed


test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 5 additions & 7 deletions
```diff
@@ -68,24 +68,22 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
     )


-def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
-    # TODO(future PR): assert that the K dim must be divisible by block size,
-    # today this is silently incorrect if block_size is greater than K
     config.block_size = 16
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
-
-    # TODO(future PR): compile
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=True, allgather_in_lowp=False
+    )


 if __name__ == "__main__":
     device_mesh = setup_distributed()
     tests = [
         _test_dtensor_cast_to_mxfp8,
-        # TODO(next PR): enable this (current PR got too large, so splitting)
-        # _test_mxfp8_mlp_tensor_parallelism_eager,
+        _test_mxfp8_mlp_tensor_parallelism,
     ]

     for test in tqdm(tests, desc="Running tests"):
```

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -190,8 +190,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
 # TODO(future): enable compile support
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_activation_checkpointing():
-    input_shape = (2, 4)
-    grad_shape = (2, 8)
+    input_shape = (16, 4)
+    grad_shape = (16, 8)
     elem_dtype = torch.float8_e4m3fn

     m = nn.Sequential(
```

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -72,7 +72,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
-    data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)

```

torchao/prototype/mx_formats/kernels.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -1056,7 +1056,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:

     # effective mx block size since we're packing 2 fp4 into 1 uint8
     packed_mx_block_size = 3 * mx_block_size // 4
-    packed_shape = [uint8_data.shape[0], packed_mx_block_size]
+    packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size]
     n_mx_blocks = uint8_data.numel() // mx_block_size

     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
@@ -1337,7 +1337,9 @@ def triton_to_mxfp8_dim1(

     # Create scale tensors
     col_scale = torch.empty(
-        (n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device
+        (n_cols, n_rows // inner_block_size, 1),
+        dtype=torch.uint8,
+        device=x.device,
     )

     # Calculate grid dimensions based on tile size
@@ -1374,7 +1376,7 @@ def triton_to_mxfp8_dim1_reference(
     scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
         x_hp_d1, torch.float8_e4m3fn, block_size
     )
-    scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
+    scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
     return (
         x_hp_d1_normalized.t(),
         scale_e8m0_dim1,
```
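
As a quick illustration of the `packed_shape` change above (shapes here are made up, and this is plain bookkeeping, not the Triton kernel): keeping every leading dim and swapping only the trailing mx-block dim for its packed size lets the same logic handle rank-2 and rank-3 inputs.

```python
import torch

mx_block_size = 32
packed_mx_block_size = 3 * mx_block_size // 4  # as in the kernel above

for shape in [(128, mx_block_size), (4, 128, mx_block_size)]:
    uint8_data = torch.zeros(shape, dtype=torch.uint8)
    # Keep leading dims, replace only the last (mx block) dim.
    packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size]
    n_mx_blocks = uint8_data.numel() // mx_block_size
    print(packed_shape, n_mx_blocks)
# [128, 24] 128
# [4, 128, 24] 512
```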

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 25 additions & 25 deletions
```diff
@@ -25,7 +25,6 @@

 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
-    BF16_EXP_BIAS,
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
@@ -62,7 +61,6 @@

 # TODO(later): read from somewhere else?
 SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
-EBITS_BF16, MBITS_BF16 = 8, 7
 EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
 EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
@@ -137,9 +135,7 @@ def _to_mx_rceil(
     )

     # scale and saturated cast the data elements to max of target dtype
-    data_lp = torch.clamp(
-        data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
-    )
+    data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
     return exponent, data_lp


@@ -160,22 +156,33 @@ def to_mx(
         torch.float,
     ), f"{data_hp.dtype} is not supported yet"
     # TODO(future PR): consider supporting padding
-    assert data_hp.numel() % block_size == 0, "unsupported"
+    assert data_hp.shape[-1] % block_size == 0, (
+        f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
+    )
     assert data_hp.is_contiguous(), "unsupported"
     assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

-    # calculate the scale in e8m0 format
-
     orig_shape = data_hp.shape
-    # TODO(future PR): fix this line for TP, currently this reshape does not work
-    # for rank 3 tensor where dim1 is sharded
-    data_hp = data_hp.reshape(-1, block_size)
+    data_hp = data_hp.reshape(
+        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
+    )

     # find max value of the data
     # Note: this only implements the `minimally supported` version of
     # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     # section 6.3.
-    max_abs = torch.amax(torch.abs(data_hp), 1)
+    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
+
+    # We cast to float32 here because
+    # in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below,
+    # if tensor parallel is enabled then the resulting shape is 2x larger
+    # than it should be under some conditions, likely because of a bug in
+    # the `view` op with DTensor and target dtype int16. I reproduce in
+    # torchtitan but not in a unit test, so not enough info to file a good
+    # issue in pytorch/pytorch. For now, work around. In the future we should
+    # debug and fix this properly.
+    data_hp = data_hp.to(torch.float32)
+    max_abs = max_abs.to(torch.float32)

     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable
@@ -206,17 +213,11 @@ def to_mx(
     if scaling_mode == ScaleCalculationMode.RCEIL:
         scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
     else:
-        if data_hp.dtype is torch.float32:
-            hp_int_dtype = torch.int32
-            hp_mbits = MBITS_F32
-            hp_ebits = EBITS_F32
-            hp_exp_bias = F32_EXP_BIAS
-        else:
-            assert data_hp.dtype is torch.bfloat16
-            hp_int_dtype = torch.int16
-            hp_mbits = MBITS_BF16
-            hp_ebits = EBITS_BF16
-            hp_exp_bias = BF16_EXP_BIAS
+        assert data_hp.dtype is torch.float32
+        hp_int_dtype = torch.int32
+        hp_mbits = MBITS_F32
+        hp_ebits = EBITS_F32
+        hp_exp_bias = F32_EXP_BIAS

         # rounding before calculating the largest power of 2
         # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
@@ -285,7 +286,7 @@ def to_mx(
     scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

     # scale and saturated cast the data elements to max of target dtype
-    data_lp = data_hp / scale_fp32.unsqueeze(1)
+    data_lp = data_hp / scale_fp32

     if (
         elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
@@ -511,7 +512,6 @@ def __new__(
         assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, (
            f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
         )
-        assert len(scale_e8m0_bits.shape) == 1, "unsupported"
         assert data_bits.dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
```
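
For context on the float32 workaround in `to_mx` above: the non-RCEIL scale path bitcasts `max_abs` via `max_abs.view(hp_int_dtype)`, and normalizing the input to float32 means that bitcast always targets int32 rather than the int16 path that hit the DTensor `view` bug. Below is a minimal sketch of reading a biased float32 exponent through such a bitcast; `biased_exponent_f32` is a hypothetical helper, not the library code.

```python
import torch

MBITS_F32 = 23  # number of mantissa bits in float32

def biased_exponent_f32(max_abs: torch.Tensor) -> torch.Tensor:
    # Bitcast (no value conversion), then shift out the mantissa bits and
    # mask to keep the 8-bit biased exponent field.
    assert max_abs.dtype is torch.float32
    max_abs_int32 = max_abs.view(torch.int32)
    return ((max_abs_int32 >> MBITS_F32) & 0xFF).to(torch.uint8)

x = torch.tensor([1.0, 0.5, 6.0], dtype=torch.float32)
print(biased_exponent_f32(x))  # tensor([127, 126, 129], dtype=torch.uint8)
```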

torchao/testing/training/dtensor_utils.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -152,15 +152,18 @@ def _test_lowp_mlp_tensor_parallelism_base(
         sp_model2 = torch.compile(sp_model2)

     x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
+    go_fp32_tp = go_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
+    go_fp32_sp = distribute_tensor(go_fp32.clone(), mesh, [Shard(0)])

     tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
+    tp_out.backward(go_fp32_tp)
     sp_out = sp_model(x_fp32_sp_input)
-    sp_out.sum().backward()
+    sp_out.backward(go_fp32_sp)
     global_out = toy_model_fp8(x_fp32)
-    global_out.sum().backward()
+    global_out.backward(go_fp32)
     torch.testing.assert_close(tp_out, global_out)
     torch.testing.assert_close(sp_out.full_tensor(), global_out)
     torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
@@ -169,7 +172,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
     )

     sp_out2 = sp_model2(x_fp32_sp_input)
-    sp_out2.sum().backward()
+    sp_out2.backward(go_fp32_sp)
     torch.testing.assert_close(sp_out2.full_tensor(), global_out)
     torch.testing.assert_close(
         tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
```
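
On the test-harness change above: replacing `out.sum().backward()` with `out.backward(go)` drives every model copy with the same explicit upstream gradient. A small self-contained illustration, with plain `nn.Linear` modules standing in for the TP/SP models (names and shapes are made up):

```python
import torch

# Two identical linear layers stand in for the TP and single-device models.
lin_a = torch.nn.Linear(4, 4)
lin_b = torch.nn.Linear(4, 4)
lin_b.load_state_dict(lin_a.state_dict())

x = torch.randn(8, 4)
go = torch.rand(8, 4)  # shared upstream gradient, analogous to go_fp32 above

lin_a(x).backward(go)  # equivalent to (out * go).sum().backward()
lin_b(x).backward(go)
torch.testing.assert_close(lin_a.weight.grad, lin_b.weight.grad)
```

Passing a shared random `go` exercises the backward with a non-trivial upstream gradient instead of the all-ones gradient implied by `.sum()`.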
