Support MXFP8 all-gather with only column-wise data
Signed-off-by: Tim Moon <[email protected]>
timmoon10 committed Feb 25, 2025
1 parent 2099726 commit 7f4dfdb
Showing 1 changed file with 59 additions and 40 deletions.
transformer_engine/pytorch/distributed.py
@@ -923,20 +923,27 @@ def _all_gather_mxfp8(
     if out_shape is None:
         out_shape = [in_shape[0] * world_size] + in_shape[1:]
 
-    # Gather MXFP8 data for row-wise usage
-    if quantizer.rowwise_usage and not quantizer.columnwise_usage:
-
-        # Cast input tensor to MXFP8 if needed
-        if not isinstance(input_, MXFP8TensorBase):
-            input_ = quantizer(input_)
-
-        # Construct MXFP8 output tensor
-        dtype = torch.float32
-        device = "cuda"
-        if isinstance(input_, MXFP8Tensor):
-            dtype = input_.dtype
-            device = input_.device
-        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+    # Cast input tensor to MXFP8 with required data
+    if not isinstance(input_, MXFP8TensorBase):
+        input_ = quantizer(input_)
+    elif (
+        input_.rowwise_data is None and quantizer.rowwise_usage
+        or input_.columnwise_data is None and quantizer.columnwise_usage
+    ):
+        warnings.warn(
+            "Input and quantizer do not have matching usages. "
+            "Dequantizing and requantizing to MXFP8."
+        )
+        input_ = quantizer(input_.dequantize())
+
+    # Construct MXFP8 output tensor
+    out = quantizer.make_empty(out_shape, dtype=input_.dtype, device=input_.device)
+
+    # Async op handle
+    handle = None
+
+    # Gather MXFP8 data for row-wise usage
+    if quantizer.rowwise_usage:
 
         # Remove padding from MXFP8 scale-inverses
         in_scale_inv = input_._rowwise_scale_inv
@@ -948,36 +955,48 @@
             out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
 
         # Launch all-gathers
-        with torch.distributed._coalescing_manager(
-            group=process_group,
-            device=device,
-            async_ops=async_op,
-        ) as coalescing_manager:
-            torch.distributed.all_gather_into_tensor(
-                out._rowwise_data,
-                input_._rowwise_data,
-                group=process_group,
-            )
-            torch.distributed.all_gather_into_tensor(
-                out_scale_inv,
-                in_scale_inv,
-                group=process_group,
-            )
-        handle = coalescing_manager if async_op else None
-        return out, handle
-
-    # Gather in high precision and quantize for column-wise usage
-    if isinstance(input_, QuantizedTensor):
-        input_ = input_.dequantize(dtype=torch.bfloat16)
-    out = torch.empty(
-        out_shape,
-        dtype=input_.dtype,
-        device=input_.device,
-        memory_format=torch.contiguous_format,
-    )
-    torch.distributed.all_gather_into_tensor(out, input_, group=process_group)
-    out = quantizer(out)
-    return out, None
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
+            group=process_group,
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._rowwise_data,
+            input_._rowwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
+
+    # Gather MXFP8 data for column-wise usage
+    if quantizer.columnwise_usage:
+
+        # Remove padding from MXFP8 scale-inverses
+        in_scale_inv = input_._columnwise_scale_inv
+        out_scale_inv = out._columnwise_scale_inv
+        flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
+        if in_scale_inv.size(0) != flattened_in_shape0:
+            in_scale_inv = in_scale_inv[:flattened_in_shape0]
+            out_scale_inv[flattened_in_shape0 * world_size :].zero_()
+            out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
+
+        # Launch all-gathers
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
+            group=process_group,
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._columnwise_data,
+            input_._columnwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
+
+    return out, handle
 
 
 def gather_along_first_dim(
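
The updated control flow serializes the two usages on a single async handle: each branch waits on any previous data gather, gathers its scale-inverses with a blocking call, then launches its own data all-gather asynchronously, and only the final handle is returned. Below is a minimal standalone sketch of that pattern with plain torch tensors rather than TE's MXFP8 internals; the helper name and the (scale, data) pairing are illustrative assumptions, and all_gather_into_tensor requires a backend that supports it (e.g. NCCL) inside an initialized process group.

import torch
import torch.distributed as dist

def chained_all_gather(components, group=None, async_op=True):
    """Gather (scale, data) pairs along dim 0, one usage at a time.

    Mirrors the commit's control flow: wait on the previous usage's
    in-flight data gather, gather the scales with a blocking call, then
    launch the data gather asynchronously. Only the last handle is
    returned, as in the updated _all_gather_mxfp8.
    """
    world_size = dist.get_world_size(group)
    outs, handle = [], None
    for scale, data in components:
        out_scale = torch.empty(
            (scale.size(0) * world_size, *scale.shape[1:]),
            dtype=scale.dtype, device=scale.device,
        )
        out_data = torch.empty(
            (data.size(0) * world_size, *data.shape[1:]),
            dtype=data.dtype, device=data.device,
        )
        # Serialize behind the previous usage's data gather, if any
        if handle is not None:
            handle.wait()
        # Small scale tensors are gathered synchronously...
        dist.all_gather_into_tensor(out_scale, scale, group=group)
        # ...while the larger data gather may stay in flight
        handle = dist.all_gather_into_tensor(
            out_data, data, group=group, async_op=async_op
        )
        outs.append((out_scale, out_data))
    return outs, handle

One consequence visible in the diff: when both usages are enabled, the row-wise data gather can overlap with the column-wise padding bookkeeping but never with the column-wise collectives, so waiting on the single returned handle is enough to know all gathered data is ready.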
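Both branches also strip padding from the MXFP8 scale-inverse tensors before gathering, and the column-wise branch computes flattened_in_shape0 with an extra // 32 because MXFP8 scales blocks of 32 elements along the tiled dimension. The sketch below replays the row-wise slicing with concrete numbers; the [128, 64] input shape, 4 ranks, and the padded allocation of 256 rows are purely illustrative assumptions, since the real padding rule is internal to Transformer Engine.

import math
import torch

# Illustrative shapes: per-rank input [128, 64], 4 ranks, MXFP8 block size 32.
in_shape, world_size, BLOCK = [128, 64], 4, 32

# Row-wise usage: scales tile the last dim, one scale row per input row.
rowwise_rows = math.prod(in_shape[:-1])               # 128
# Column-wise usage: scales tile the leading dims in blocks of 32,
# which is where the new "// 32" in the diff comes from.
columnwise_rows = math.prod(in_shape[:-1]) // BLOCK   # 4
assert (rowwise_rows, columnwise_rows) == (128, 4)

# Assume the scale-inverse allocation is padded to 256 rows (hypothetical).
in_scale_inv = torch.zeros(256, in_shape[-1] // BLOCK)
out_scale_inv = torch.zeros(256 * world_size, in_shape[-1] // BLOCK)

# The diff's padding removal, replayed for the row-wise case:
flattened_in_shape0 = rowwise_rows
if in_scale_inv.size(0) != flattened_in_shape0:
    # Send only the meaningful rows of the local scales...
    in_scale_inv = in_scale_inv[:flattened_in_shape0]
    # ...zero the padded tail of the output so stale values cannot leak...
    out_scale_inv[flattened_in_shape0 * world_size :].zero_()
    # ...and gather into the contiguous, unpadded leading slice.
    out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]

assert in_scale_inv.shape == (128, 2)
assert out_scale_inv.shape == (512, 2)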
