
mxfp8 training: add TP sharding strategy for dim1 kernel #2436

Open · wants to merge 24 commits into main
test/float8/test_dtensor.py (2 additions, 2 deletions)
@@ -183,7 +183,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
loss.backward()


def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32):
tensorwise_config = Float8LinearConfig(emulate=True)
_test_lowp_mlp_tensor_parallelism_base(
mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
)


def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
tensorwise_config = Float8LinearConfig(emulate=True)
_test_lowp_mlp_tensor_parallelism_base(
mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True
test/float8/test_fsdp2_tp.py (3 additions, 3 deletions)
@@ -61,7 +61,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
enable_fsdp_float8_all_gather=True,
)

toy_model = ToyModel().to(device)
toy_model = ToyModel(size).to(device)

tp_model = copy.deepcopy(toy_model)
tp_model = convert_to_float8_training(tp_model, config=config)
@@ -94,11 +94,11 @@ def _test_fp8_mlp_tensor_parallelism_base(
# TODO(future PR): test numerics, and add more cases


def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32):
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)


def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)


test/prototype/mx_formats/test_mx_dtensor.py (17 additions, 2 deletions)
@@ -68,9 +68,9 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
)


def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 16
config.block_size = 32
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)
@@ -79,11 +79,26 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
)


def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
config.use_fp8_dim1_cast_triton_kernel = True
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)
# TODO(future PR): enable compile here, currently seeing
# https://www.internalfb.com/phabricator/paste/view/P1851219639
# _test_lowp_mlp_tensor_parallelism_base(
Review comment (author): need to uncomment this and run ./test/prototype/mx_formats/test_mx_dtensor.sh to reproduce.
# mesh, config, size, compile=True, allgather_in_lowp=False
# )


if __name__ == "__main__":
device_mesh = setup_distributed()
tests = [
_test_dtensor_cast_to_mxfp8,
_test_mxfp8_mlp_tensor_parallelism,
_test_mxfp8_mlp_tensor_parallelism_dim1_triton,
]

for test in tqdm(tests, desc="Running tests"):
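The new `_test_mxfp8_mlp_tensor_parallelism_dim1_triton` drives the dim1 Triton cast through the shared TP harness. As a rough single-process illustration of what `use_fp8_dim1_cast_triton_kernel` enables (no DeviceMesh or DTensor involved), something like the sketch below should exercise the same cast path in the backward pass. It assumes a CUDA device with Triton available and uses a plain two-layer `nn.Linear` stack in place of the harness's model, so treat it as illustrative rather than part of the PR's tests.

```python
# Single-GPU sketch (assumes CUDA + Triton; not part of the PR's tests).
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.quantization import quantize_

# Same knobs the new dim1-triton test sets.
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
config.use_fp8_dim1_cast_triton_kernel = True

model = nn.Sequential(
    nn.Linear(128, 256, bias=False),
    nn.Linear(256, 128, bias=False),
).to("cuda", torch.bfloat16)
quantize_(model, config=config)

x = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = model(x)
# The dim1 Triton cast runs in the backward pass (weight, grad_output,
# and input casts), which is the code path the new TP test shards.
y.sum().backward()
```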
torchao/prototype/mx_formats/kernels.py (17 additions, 1 deletion)
@@ -1315,7 +1315,8 @@ def triton_to_mxfp8_dim1(
* `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1
"""
assert x.is_contiguous(), "`x` must be contiguous"
assert x.dtype == torch.bfloat16
# TODO(before land): maybe gate by FakeTensor below?
# assert x.dtype == torch.bfloat16
assert inner_block_size <= 32

# Get tensor shape
@@ -1363,6 +1364,21 @@ def triton_to_mxfp8_dim1(
col_scale.view(torch.float8_e8m0fnu),
)

# print(torch.ops.torchao.triton_to_mxfp8_dim1.default)

from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.experimental import register_sharding

@register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32):
replicate = ([Replicate(), Replicate()], [Replicate(), None])
# Note that the data is returned transposed, which is why
# we flip the sharding dim below
shard_dim0 = ([Shard(1), Shard(1)], [Shard(0), None])
shard_dim1 = ([Shard(0), Shard(0)], [Shard(1), None])
acceptable_shardings = [replicate, shard_dim0, shard_dim1]
return acceptable_shardings

def triton_to_mxfp8_dim1_reference(
x_hp: torch.Tensor, block_size
) -> Tuple[torch.Tensor, torch.Tensor]:
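For context, `register_sharding` (from `torch.distributed.tensor.experimental`) tells DTensor which placement combinations the custom op supports, so a sharded input can be dispatched straight to the local Triton kernel instead of being redistributed. Per that API's contract, each returned entry pairs the output placements (for the data and scale outputs) with the input placements (for `x`, plus `None` for the non-tensor `inner_block_size`). A minimal usage sketch, assuming a 2-GPU torchrun launch and a torchao build where this kernel and registration are importable, might look like:

```python
# Illustrative sketch (not part of the PR). Launch with:
#   torchrun --nproc_per_node=2 sketch.py
# Assumes 2 CUDA devices and that importing the kernels module registers
# torch.ops.torchao.triton_to_mxfp8_dim1 plus the sharding rule above.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1

torch.manual_seed(0)
mesh = init_device_mesh("cuda", (2,))
x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
x_dt = distribute_tensor(x, mesh, [Shard(0)])  # matches the `shard_dim0` entry

# Per the rule above, both outputs should come back as DTensors sharded
# along dim 1 (the data is returned transposed), with no redistribution of x.
data_dt, scale_dt = triton_to_mxfp8_dim1(x_dt, inner_block_size=32)
print(data_dt.placements, scale_dt.placements)
```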
torchao/prototype/mx_formats/mx_linear.py (49 additions, 31 deletions)
@@ -12,6 +12,7 @@

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor

from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
@@ -25,6 +26,46 @@
)


def _triton_to_mxfp8_dim1_wrapper(
a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
):
a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
if isinstance(a_data, DTensor):
assert isinstance(a_scale, DTensor)
a_data_local = a_data.to_local()
a_scale_local = a_scale.to_local()
inner = MXTensor(
a_scale_local,
a_data_local.t(),
elem_dtype,
block_size,
hp_dtype,
False,
gemm_kernel_choice,
False,
)
mx_tensor = DTensor.from_local(
inner,
a_data.device_mesh,
a_data.placements,
run_check=False,
shape=a_data.t().size(),
stride=a_data.t().stride(),
)
else:
mx_tensor = MXTensor(
a_scale,
a_data.t(),
elem_dtype,
block_size,
hp_dtype,
False,
gemm_kernel_choice,
False,
)
return mx_tensor


@torch._dynamo.allow_in_graph
class mx_mm(torch.autograd.Function):
# There are three gemms in a forward + backward of a Linear layer:
@@ -95,20 +136,9 @@ def backward(ctx, grad_output_hp: torch.Tensor):
)

if use_fp8_dim1_cast_triton_kernel:
weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1(
weight_hp, block_size
weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
)
weight_mx_dim1 = MXTensor(
weight_mx_dim1_scale.reshape(-1),
weight_mx_dim1_data.t(),
w_elem_dtype,
block_size,
weight_hp.dtype,
False,
gemm_kernel_choice,
False,
)

else:
weight_hp_t_c = weight_hp.t().contiguous()
weight_mx_dim1 = MXTensor.to_mx(
@@ -124,18 +154,12 @@ def backward(ctx, grad_output_hp: torch.Tensor):

# input_t @ grad_output = grad_weight
if use_fp8_dim1_cast_triton_kernel:
grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1(
grad_output_hp_r, block_size
)
grad_output_mx_dim1 = MXTensor(
grad_output_mx_dim1_scale.reshape(-1),
grad_output_mx_dim1_data.t(),
grad_elem_dtype,
grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
grad_output_hp_r,
block_size,
grad_elem_dtype,
grad_output_hp_r.dtype,
False,
gemm_kernel_choice,
False,
)
else:
grad_output_mx_dim1 = MXTensor.to_mx(
@@ -146,18 +170,12 @@ def backward(ctx, grad_output_hp: torch.Tensor):
)

if use_fp8_dim1_cast_triton_kernel:
input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1(
input_hp_r, block_size
)
input_t_mx_dim0_tmp = MXTensor(
input_t_mx_dim0_tmp_scale.reshape(-1),
input_t_mx_dim0_tmp_data.t(),
in_elem_dtype,
input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
input_hp_r,
block_size,
in_elem_dtype,
input_hp_r.dtype,
False,
gemm_kernel_choice,
False,
)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
else:
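The `_triton_to_mxfp8_dim1_wrapper` above follows the usual recipe for running a plain-tensor kernel under DTensor: call `to_local()`, run the kernel on the shard, then rebuild a DTensor with `DTensor.from_local`, passing the global `shape`/`stride` explicitly because the data comes back transposed. Below is a simplified, self-contained sketch of that round trip; it uses a stand-in transpose op instead of the real Triton kernel and flips the shard dim by hand, so it illustrates the mechanics rather than reproducing the wrapper exactly:

```python
# Simplified sketch of the to_local()/from_local() round trip used by
# _triton_to_mxfp8_dim1_wrapper. local_transpose_op is a stand-in for the
# real kernel; the real wrapper also builds an MXTensor around the local
# data and scale before rewrapping.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor


def local_transpose_op(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for a kernel whose output is laid out along the other dim.
    return t.t()


def wrap_local_result(a: DTensor) -> DTensor:
    out_local = local_transpose_op(a.to_local())
    # The result is transposed, so a shard along dim 0 of the input becomes a
    # shard along dim 1 of the output (and vice versa); shape/stride describe
    # the global transposed tensor. Assumes a 2-D tensor on a 1-D mesh.
    out_placements = [
        Shard(1 - p.dim) if isinstance(p, Shard) else p for p in a.placements
    ]
    return DTensor.from_local(
        out_local,
        a.device_mesh,
        out_placements,
        run_check=False,
        shape=a.t().size(),
        stride=a.t().stride(),
    )


if __name__ == "__main__":
    # Run with: torchrun --nproc_per_node=2 sketch.py (assumes 2 CUDA devices)
    torch.manual_seed(0)
    mesh = init_device_mesh("cuda", (2,))
    x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
    x_dt = distribute_tensor(x, mesh, [Shard(0)])
    y_dt = wrap_local_result(x_dt)
    print(y_dt.shape, y_dt.placements)  # torch.Size([64, 128]) (Shard(dim=1),)
```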
torchao/testing/training/dtensor_utils.py (10 additions, 10 deletions)
@@ -32,11 +32,11 @@
class FeedForward(nn.Module):
"""MLP based model"""

def __init__(self):
def __init__(self, size):
super(FeedForward, self).__init__()
self.w1 = nn.Linear(16, 32, bias=False)
self.w2 = nn.Linear(16, 32, bias=False)
self.out_proj = nn.Linear(32, 16, bias=False)
self.w1 = nn.Linear(size, size * 2, bias=False)
self.w2 = nn.Linear(size, size * 2, bias=False)
self.out_proj = nn.Linear(size * 2, size, bias=False)

def forward(self, x):
x = F.silu(self.w1(x)) * self.w2(x)
@@ -45,9 +45,9 @@ def forward(self, x):


class ToyModel(nn.Module):
def __init__(self):
def __init__(self, size):
super(ToyModel, self).__init__()
self.ffn = FeedForward()
self.ffn = FeedForward(size)

def forward(self, x):
return self.ffn(x)
@@ -56,7 +56,7 @@ def forward(self, x):
def _test_lowp_mlp_tensor_parallelism_base(
mesh: DeviceMesh,
config: Union[Float8LinearConfig, MXLinearConfig],
size=16,
size=32,
compile: bool = False,
allgather_in_lowp: bool = False,
):
@@ -67,7 +67,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
if isinstance(config, MXLinearConfig):
convert_model_func = quantize_

toy_model = ToyModel().to(device)
toy_model = ToyModel(size).to(device)
toy_model_fp8 = copy.deepcopy(toy_model)
convert_model_func(toy_model_fp8, config=config)

@@ -151,8 +151,8 @@ def _test_lowp_mlp_tensor_parallelism_base(
sp_model = torch.compile(sp_model)
sp_model2 = torch.compile(sp_model2)

x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32 = torch.rand(1, size * 2, size, device=device, requires_grad=False)
go_fp32 = torch.rand(1, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
go_fp32_tp = go_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
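For context on how the now size-parametrized `ToyModel` is typically sharded in these tests, here is a sketch of one plausible tensor-parallel plan for this MLP shape (colwise `w1`/`w2`, rowwise `out_proj`). It is illustrative only and not necessarily the exact plan `_test_lowp_mlp_tensor_parallelism_base` applies:

```python
# Illustrative TP sketch (not the test itself). Run with:
#   torchrun --nproc_per_node=2 sketch.py  (assumes 2 CUDA devices and that
#   torchao.testing.training.dtensor_utils is importable, per this diff)
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from torchao.testing.training.dtensor_utils import ToyModel

torch.manual_seed(0)
size = 32
mesh = init_device_mesh("cuda", (2,))
model = ToyModel(size).to("cuda")

# Shard w1/w2 over output features and out_proj over input features, so the
# final matmul reduces the partial results across ranks.
tp_model = parallelize_module(
    model,
    mesh,
    {
        "ffn.w1": ColwiseParallel(),
        "ffn.w2": ColwiseParallel(),
        "ffn.out_proj": RowwiseParallel(),
    },
)

x = torch.rand(1, size * 2, size, device="cuda")  # same layout as x_fp32 above
print(tp_model(x).shape)  # torch.Size([1, 64, 32])
```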