enable tensor parallelism for MXLinear #2434

Merged · 21 commits · Jun 24, 2025
12 changes: 5 additions & 7 deletions test/prototype/mx_formats/test_mx_dtensor.py
@@ -68,24 +68,22 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
)


def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
# TODO(future PR): assert that the K dim must be divisible by block size,
# today this is silently incorrect if block_size is greater than K
config.block_size = 16
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)

# TODO(future PR): compile
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=True, allgather_in_lowp=False
)


if __name__ == "__main__":
device_mesh = setup_distributed()
tests = [
_test_dtensor_cast_to_mxfp8,
# TODO(next PR): enable this (current PR got too large, so splitting)
# _test_mxfp8_mlp_tensor_parallelism_eager,
_test_mxfp8_mlp_tensor_parallelism,
]

for test in tqdm(tests, desc="Running tests"):
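
For context, this test file runs as a standalone distributed script (see the __main__ block above with setup_distributed and the tqdm loop) rather than through pytest. A typical way to launch it on a multi-GPU host is shown below; the exact launcher invocation is an assumption for illustration, not part of this PR.

torchrun --nproc_per_node=2 test/prototype/mx_formats/test_mx_dtensor.py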
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -190,8 +190,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
# TODO(future): enable compile support
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
input_shape = (2, 4)
grad_shape = (2, 8)
input_shape = (16, 4)
[Contributor Author] this was broken before, caught by enforcing that inner dim is divisible by block size

grad_shape = (16, 8)
elem_dtype = torch.float8_e4m3fn

m = nn.Sequential(
2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -72,7 +72,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_hello_world(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
block_size = 4
_test_mx(data, elem_dtype, block_size)
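
The cast that test_hello_world exercises bottoms out in the to_mx helper modified later in this PR. A minimal standalone sketch of that low-level call, with illustrative shapes and an import path assumed from the file layout shown in this diff:

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx

data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
# one e8m0 scale per contiguous block of 4 elements along the last dim
scale_e8m0, data_lp = to_mx(data, torch.float8_e4m3fn, block_size=4)
print(data_lp.shape, scale_e8m0.shape)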

8 changes: 5 additions & 3 deletions torchao/prototype/mx_formats/kernels.py
@@ -1056,7 +1056,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:

# effective mx block size since we're packing 2 fp4 into 1 uint8
packed_mx_block_size = 3 * mx_block_size // 4
packed_shape = [uint8_data.shape[0], packed_mx_block_size]
packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size]
n_mx_blocks = uint8_data.numel() // mx_block_size

grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
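
A shape-only sketch of what the new packed_shape computation preserves, with assumed sizes: only the last (block) dimension shrinks to the packed width, and any leading dimensions pass through unchanged.

import torch

uint8_data = torch.zeros(2, 8, 32, dtype=torch.uint8)  # rank-3 input, last dim = mx block
mx_block_size = 32
packed_mx_block_size = 3 * mx_block_size // 4
packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size]
print(packed_shape)  # [2, 8, 24]; the previous [shape[0], 24] form would have dropped the middle dim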
@@ -1337,7 +1337,9 @@ def triton_to_mxfp8_dim1(

# Create scale tensors
col_scale = torch.empty(
(n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device
(n_cols, n_rows // inner_block_size, 1),
dtype=torch.uint8,
device=x.device,
)

# Calculate grid dimensions based on tile size
@@ -1374,7 +1376,7 @@ def triton_to_mxfp8_dim1_reference(
scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
x_hp_d1, torch.float8_e4m3fn, block_size
)
scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
return (
x_hp_d1_normalized.t(),
scale_e8m0_dim1,
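
For readability, the column-scale allocation changes from a flat layout to one block dimension per column; a shape-only comparison with assumed sizes:

n_rows, n_cols, inner_block_size = 128, 64, 32
old_col_scale_shape = (n_cols * n_rows // inner_block_size, 1)  # (256, 1), all blocks flattened
new_col_scale_shape = (n_cols, n_rows // inner_block_size, 1)   # (64, 4, 1), blocks kept per column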
50 changes: 25 additions & 25 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -25,7 +25,6 @@

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
BF16_EXP_BIAS,
BLOCK_SIZE_DEFAULT,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
@@ -62,7 +61,6 @@

# TODO(later): read from somewhere else?
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
EBITS_BF16, MBITS_BF16 = 8, 7
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
@@ -137,9 +135,7 @@ def _to_mx_rceil(
)

# scale and saturated cast the data elements to max of target dtype
data_lp = torch.clamp(
data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
)
data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
return exponent, data_lp


@@ -160,22 +156,33 @@ def to_mx(
torch.float,
), f"{data_hp.dtype} is not supported yet"
# TODO(future PR): consider supporting padding
assert data_hp.numel() % block_size == 0, "unsupported"
assert data_hp.shape[-1] % block_size == 0, (
f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
)
assert data_hp.is_contiguous(), "unsupported"
assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

# calculate the scale in e8m0 format

orig_shape = data_hp.shape
# TODO(future PR): fix this line for TP, currently this reshape does not work
# for rank 3 tensor where dim1 is sharded
data_hp = data_hp.reshape(-1, block_size)
data_hp = data_hp.reshape(
*orig_shape[:-1], orig_shape[-1] // block_size, block_size
)

# find max value of the data
# Note: this only implements the `minimally supported` version of
# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
# section 6.3.
max_abs = torch.amax(torch.abs(data_hp), 1)
max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)

# We cast to float32 here because
# in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below,
# if tensor parallel is enabled then the resulting shape is 2x larger
# than it should be under some conditions, likely because of a bug in
# the `view` op with DTensor and target dtype int16. I reproduce in
# torchtitan but not in a unit test, so not enough info to file a good
# issue in pytorch/pytorch. For now, work around. In the future we should
# debug and fix this properly.
data_hp = data_hp.to(torch.float32)
[Contributor Author] performance testing showed that with compile on, having this in float32 does not regress performance

max_abs = max_abs.to(torch.float32)

# Set X to be the largest power-of-two less than or equal to
# max_abs(v), divided by the largest power of two representable
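
A minimal standalone sketch of the new per-block layout in to_mx, with illustrative sizes: the last dimension is split into (num_blocks, block_size) while leading dimensions, including the rank-3 tensor-parallel case this PR targets, are left intact, and the per-block amax keeps a trailing singleton dimension for later broadcasting.

import torch

x = torch.randn(2, 64, 128, dtype=torch.bfloat16)  # e.g. (batch, seq, K)
block_size = 32

x_blocks = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
max_abs = torch.amax(torch.abs(x_blocks), -1).unsqueeze(-1)
print(x_blocks.shape, max_abs.shape)  # torch.Size([2, 64, 4, 32]) torch.Size([2, 64, 4, 1])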
@@ -206,17 +213,11 @@ def to_mx(
if scaling_mode == ScaleCalculationMode.RCEIL:
scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
else:
if data_hp.dtype is torch.float32:
hp_int_dtype = torch.int32
hp_mbits = MBITS_F32
hp_ebits = EBITS_F32
hp_exp_bias = F32_EXP_BIAS
else:
assert data_hp.dtype is torch.bfloat16
hp_int_dtype = torch.int16
hp_mbits = MBITS_BF16
hp_ebits = EBITS_BF16
hp_exp_bias = BF16_EXP_BIAS
assert data_hp.dtype is torch.float32
hp_int_dtype = torch.int32
hp_mbits = MBITS_F32
hp_ebits = EBITS_F32
hp_exp_bias = F32_EXP_BIAS

# rounding before calculating the largest power of 2
# X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
@@ -285,7 +286,7 @@ def to_mx(
scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

# scale and saturated cast the data elements to max of target dtype
data_lp = data_hp / scale_fp32.unsqueeze(1)
data_lp = data_hp / scale_fp32

if (
elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
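
Because max_abs (and hence the derived scale) keeps that trailing singleton dimension, the division broadcasts across block_size directly, which is why the explicit unsqueeze is no longer needed. A small shape check under the same illustrative sizes:

import torch

data_hp = torch.randn(2, 64, 4, 32)          # (..., num_blocks, block_size)
scale_fp32 = torch.rand(2, 64, 4, 1) + 0.5   # (..., num_blocks, 1)
data_lp = data_hp / scale_fp32               # broadcasts over the block_size dim
print(data_lp.shape)                         # torch.Size([2, 64, 4, 32])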
@@ -511,7 +512,6 @@ def __new__(
assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, (
f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
)
assert len(scale_e8m0_bits.shape) == 1, "unsupported"
assert data_bits.dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
11 changes: 7 additions & 4 deletions torchao/testing/training/dtensor_utils.py
@@ -152,15 +152,18 @@ def _test_lowp_mlp_tensor_parallelism_base(
sp_model2 = torch.compile(sp_model2)

x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
go_fp32_tp = go_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
go_fp32_sp = distribute_tensor(go_fp32.clone(), mesh, [Shard(0)])

tp_out = tp_model(x_fp32_tp_input)
tp_out.sum().backward()
tp_out.backward(go_fp32_tp)
[Contributor Author] to make sure grad flowing into the last linear is contiguous
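
A standalone illustration of that point (module and shapes are purely illustrative): the gradient autograd feeds backward from .sum() is an expanded tensor, while an explicit grad_output tensor is contiguous.

import torch

lin = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)
grads = []

out = lin(x)
out.register_hook(grads.append)   # capture the gradient flowing into out
out.sum().backward()
print(grads[-1].is_contiguous())  # typically False: autograd expands a scalar ones tensor here

out2 = lin(x)
out2.register_hook(grads.append)
out2.backward(torch.rand_like(out2))
print(grads[-1].is_contiguous())  # True: the explicit grad_output is contiguous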

sp_out = sp_model(x_fp32_sp_input)
sp_out.sum().backward()
sp_out.backward(go_fp32_sp)
global_out = toy_model_fp8(x_fp32)
global_out.sum().backward()
global_out.backward(go_fp32)
torch.testing.assert_close(tp_out, global_out)
torch.testing.assert_close(sp_out.full_tensor(), global_out)
torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
@@ -169,7 +172,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
)

sp_out2 = sp_model2(x_fp32_sp_input)
sp_out2.sum().backward()
sp_out2.backward(go_fp32_sp)
torch.testing.assert_close(sp_out2.full_tensor(), global_out)
torch.testing.assert_close(
tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad