enable to_mxfp8 cast for DTensor #2420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 14 commits · Jun 24, 2025
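For context, this change lets `MXTensor.to_mx` accept a `DTensor` input and return a `DTensor` whose local shard is an `MXTensor`. Below is a minimal sketch of that flow, adapted from the test added in this PR; it assumes a distributed launch (e.g. `torchrun --nproc_per_node 2`) with CUDA available and is illustrative rather than part of the diff.

```python
import torch
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

from torchao.prototype.mx_formats.mx_tensor import MXTensor

# 1-D device mesh; requires a distributed launch (e.g. torchrun --nproc_per_node 2).
mesh = init_device_mesh("cuda", (2,))

x_fp32 = torch.rand(4, 4, device="cuda")
dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])

# Cast the sharded tensor to mxfp8; the result is still a DTensor,
# and its local shard is an MXTensor.
dist_x_mxfp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=2)
assert isinstance(dist_x_mxfp8, DTensor)
assert isinstance(dist_x_mxfp8.to_local(), MXTensor)
```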
98 changes: 98 additions & 0 deletions test/prototype/mx_formats/test_mx_dtensor.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Test numerics of manually defined float16 TP vs mxfp8 TP of toy models

Note: for now, this does not run in CI.
TODO(future): make this run in CI
"""

import os

import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from tqdm import tqdm

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.testing.training.dtensor_utils import (
_test_lowp_mlp_tensor_parallelism_base,
)

torch.set_float32_matmul_precision("high")


def setup_distributed():
world_size = int(os.environ.get("WORLD_SIZE", -1))
device_mesh = init_device_mesh("cuda", (world_size,))
# seed must be the same in all processes
torch.manual_seed(1)
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
return device_mesh


def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
device = mesh.device_type

x_fp32 = torch.rand(size, size, device=device)
x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=size // 2)

dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=size // 2)
assert isinstance(dist_x_fp8, DTensor)

# Verify that the result of to_mx with DTensor matches the slice of the
# result of to_mx without DTensor. This will fail on numeric op mismatches.
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert size % world_size == 0, "unsupported"
x_fp8_fp32 = x_fp8.to_dtype(torch.float32)
rows_per_slice = size // world_size
slice_start = local_rank * rows_per_slice
slice_end = (local_rank + 1) * rows_per_slice
x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
torch.testing.assert_close(
x_fp8_fp32_slice, dist_x_fp8.to_local().to_dtype(torch.float32), atol=0, rtol=0
)


def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
# TODO(future PR): assert that the K dim must be divisible by block size,
# today this is silently incorrect if block_size is greater than K
config.block_size = 16
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)

# TODO(future PR): compile


if __name__ == "__main__":
device_mesh = setup_distributed()
tests = [
_test_dtensor_cast_to_mxfp8,
# TODO(next PR): enable this (current PR got too large, so splitting)
# _test_mxfp8_mlp_tensor_parallelism_eager,
]

for test in tqdm(tests, desc="Running tests"):
try:
test(device_mesh)
except Exception as e:
print(f"Test {test.__name__} failed with error: {e}")
raise e

torch.distributed.destroy_process_group()
17 changes: 17 additions & 0 deletions test/prototype/mx_formats/test_mx_dtensor.sh
@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# terminate script on first error
set -e

if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
echo "Skipping test_dtensor.sh because no CUDA devices are available."
exit
fi

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mx_dtensor.py
3 changes: 0 additions & 3 deletions torchao/prototype/mx_formats/kernels.py
@@ -1102,15 +1102,12 @@ def _triton_calculate_scale(x, axis):
    bf16_mbits = 7
    bf16_exp_bias = 127
    fp32_mbits = 23
-    # We use a small epsilon to avoid division by zero
-    epsilon = 1e-10

    # Find the maximum absolute value for each row
    max_abs = tl.max(x, axis=axis)

    # Calculate the e8m0 scale by extracting the exponent (floor)
    # TODO(future PR): support other exponent extraction types (ceil, RNE)
-    max_abs = max_abs + epsilon
    max_abs = max_abs.to(tl.bfloat16)
    max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
    extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
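For reference, the exponent extraction that both this Triton kernel and the eager path below rely on reads the power of two straight from the floating-point bit pattern of the (non-negative) max-abs value. A hand-written illustration, not part of the diff; the constants mirror the kernel above:

```python
import torch

# Constants mirror the Triton kernel above.
BF16_MBITS = 7        # bfloat16 mantissa bits
BF16_EXP_BIAS = 127   # bfloat16 exponent bias

# Non-negative inputs, as max_abs always is.
max_abs = torch.tensor([0.75, 1.0, 6.0, 1000.0], dtype=torch.bfloat16)
max_abs_int16 = max_abs.view(torch.int16)                # bitcast, no value conversion
biased_exp = (max_abs_int16 >> BF16_MBITS) & 0b11111111  # isolate the 8 exponent bits
extracted_pow2 = biased_exp - BF16_EXP_BIAS

print(extracted_pow2)  # tensor([-1, 0, 2, 9], dtype=torch.int16), i.e. floor(log2(max_abs))
```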
45 changes: 36 additions & 9 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -21,6 +21,7 @@
from typing import Callable, Dict, Union

import torch
+from torch.distributed._tensor import DTensor

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
@@ -166,6 +167,8 @@ def to_mx(
    # calculate the scale in e8m0 format

    orig_shape = data_hp.shape
+    # TODO(future PR): fix this line for TP, currently this reshape does not work
+    # for rank 3 tensor where dim1 is sharded
    data_hp = data_hp.reshape(-1, block_size)

    # find max value of the data
@@ -174,10 +177,6 @@
    # section 6.3.
    max_abs = torch.amax(torch.abs(data_hp), 1)

-    # Add an epsilon to prevent the log2 function call for returning -inf
-    # where the values are zero.
-    eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)

    # Set X to be the largest power-of-two less than or equal to
    # max_abs(v), divided by the largest power of two representable
    # in the element data type, and get the mbits at the same time
@@ -233,8 +232,12 @@ def to_mx(
    )

    # Calculate the scale for different modes
-    max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
-    extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
+    max_abs_int32 = max_abs.view(hp_int_dtype)
+    # For now, use `torch.bitwise_right_shift` instead of `>>` to support DTensor
+    # See https://github.com/pytorch/pytorch/issues/156533.
+    extracted_pow2 = (
+        (torch.bitwise_right_shift(max_abs_int32, hp_mbits)) & 0b11111111
+    ) - hp_exp_bias

    if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
        scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
@@ -266,9 +269,11 @@
    )

    # For now, calculate the scale in floating point.
-    scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(
-        torch.float32
-    )
+    # For now, use `torch.bitwise_left_shift` instead of `<<` to support DTensor
+    # See https://github.com/pytorch/pytorch/issues/156533.
+    scale_fp32 = (
+        torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)
+    ).view(torch.float32)

    # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
    # float32 denormal range. For now, manually adjust the fp scale. This is
@@ -597,6 +602,28 @@ def to_mx(
        scale_e8m0_biased, data_lp = to_mx(
            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
        )
+        if isinstance(scale_e8m0_biased, DTensor):
+            assert isinstance(data_lp, DTensor), "unsupported"
+            local_scale_e8m0_biased = scale_e8m0_biased.to_local()
+            local_data_lp = data_lp.to_local()
+            inner_mx_tensor = MXTensor(
+                local_scale_e8m0_biased,
+                local_data_lp,
+                elem_dtype,
+                block_size,
+                data_hp.dtype,
+                use_fp4_custom_triton_dequant_kernel,
+                gemm_kernel_choice,
+                pack_fp6,
+            )
+            return DTensor.from_local(
+                inner_mx_tensor,
+                data_lp.device_mesh,
+                data_lp.placements,
+                run_check=False,
+                shape=data_lp.size(),
+                stride=data_lp.stride(),
+            )
        return MXTensor(
            scale_e8m0_biased,
            data_lp,
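The `bitwise_left_shift` change above keeps the same arithmetic as the old `<<` expression (the functional form is used only because `<<`/`>>` are not yet supported on `DTensor`, per the linked PyTorch issue): shifting the biased e8m0 exponent into the float32 exponent field reconstructs the scale as a power of two. A small worked check, illustrative only:

```python
import torch

MBITS_F32 = 23  # float32 mantissa bits, same constant as above

# e8m0 biased exponents: 126 -> 2**-1, 127 -> 2**0, 128 -> 2**1
scale_e8m0_biased = torch.tensor([126, 127, 128], dtype=torch.uint8)

# Place the biased exponent in the float32 exponent field and bitcast.
scale_fp32 = torch.bitwise_left_shift(
    scale_e8m0_biased.to(torch.int32), MBITS_F32
).view(torch.float32)

print(scale_fp32)  # tensor([0.5000, 1.0000, 2.0000])
```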
23 changes: 17 additions & 6 deletions torchao/testing/training/dtensor_utils.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
+from typing import Union

import torch
import torch.nn as nn
@@ -24,6 +25,8 @@
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.quantization import quantize_


class FeedForward(nn.Module):
@@ -36,7 +39,9 @@ def __init__(self):
        self.out_proj = nn.Linear(32, 16, bias=False)

    def forward(self, x):
-        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
+        x = F.silu(self.w1(x)) * self.w2(x)
+        x = self.out_proj(x)
+        return x


class ToyModel(nn.Module):
@@ -50,20 +55,26 @@ def forward(self, x):

def _test_lowp_mlp_tensor_parallelism_base(
    mesh: DeviceMesh,
-    config: Float8LinearConfig,
+    config: Union[Float8LinearConfig, MXLinearConfig],
    size=16,
    compile: bool = False,
    allgather_in_lowp: bool = False,
):
    device = mesh.device_type

+    # TODO(future): remove this once float8 training works with `quantize_` API
+    convert_model_func = convert_to_float8_training
+    if isinstance(config, MXLinearConfig):
+        convert_model_func = quantize_
+
    toy_model = ToyModel().to(device)
-    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
+    toy_model_fp8 = copy.deepcopy(toy_model)
+    convert_model_func(toy_model_fp8, config=config)

    tp_model = copy.deepcopy(toy_model)
-    tp_model = convert_to_float8_training(tp_model, config=config)
+    convert_model_func(tp_model, config=config)
    sp_model = copy.deepcopy(toy_model)
-    sp_model = convert_to_float8_training(sp_model, config=config)
+    convert_model_func(sp_model, config=config)

    # For tensorwise scaling, enable float8 all_gather.
    # For rowwise scaling, keep high precision all_gather. Motivation for
@@ -108,7 +119,7 @@ def _test_lowp_mlp_tensor_parallelism_base(

    # prepare_input_cls with specific submodule fqn
    sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = convert_to_float8_training(sp_model2, config=config)
+    convert_model_func(sp_model2, config=config)

    if not allgather_in_lowp:
        prepare_input = prepare_input_cls(
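With this change, the harness converts models through `quantize_` when it receives an `MXLinearConfig`, and still uses `convert_to_float8_training` for `Float8LinearConfig`. A minimal single-process sketch of the MX call pattern, illustrative only; it reuses the recipe from the new test above and assumes a CUDA device:

```python
import copy

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization import quantize_
from torchao.testing.training.dtensor_utils import ToyModel

# Same recipe the new eager TP test configures.
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 16

toy_model = ToyModel().to("cuda")
toy_model_mx = copy.deepcopy(toy_model)
# quantize_ converts the linear layers in place according to the MX config,
# mirroring convert_model_func in the harness above.
quantize_(toy_model_mx, config=config)
```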