
Some changes in inner-padding option #858


Open · wants to merge 1 commit into main
4 changes: 2 additions & 2 deletions benchmarks/float8/bench_padding.py
@@ -50,10 +50,10 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

a_config = ScaledMMConfig(
emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
emulate=False, use_fast_accum=True, fp8_output=True, pad_dimensions=True
)
b_config = ScaledMMConfig(
emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
emulate=False, use_fast_accum=True, fp8_output=True, pad_dimensions=True
)
a_config = LinearMMConfig(a_config, a_config, a_config)
b_config = LinearMMConfig(b_config, b_config, b_config)
2 changes: 1 addition & 1 deletion test/float8/test_base.py
@@ -494,7 +494,7 @@ def test_different_configs_error(self):
"base_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_fast_accum", [True, False])
def test_pad_inner_dim(self, base_dtype, use_fast_accum):
def test_pad_dimensions(self, base_dtype, use_fast_accum):
Contributor:

I think this would now need 3 cases:

  • pad N only
  • pad K only
  • pad N and K
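
A minimal sketch of how those three cases might be parametrized. The pad_case parameter and the concrete shapes below are hypothetical, not part of this PR; the existing base_dtype / use_fast_accum parametrizations and the imports already present in test_base.py are assumed to stay in place:

    # Hypothetical extra parametrization covering: pad N only, pad K only, pad N and K.
    @pytest.mark.parametrize("pad_case", ["pad_N_only", "pad_K_only", "pad_N_and_K"])
    def test_pad_dimensions(self, base_dtype, use_fast_accum, pad_case):
        torch.manual_seed(42)
        # Pick K (inner dim) and N (output dim) so that exactly the dimensions named
        # by pad_case are not multiples of 16 and therefore require padding.
        if pad_case == "pad_N_only":
            K, N = 32, 33
        elif pad_case == "pad_K_only":
            K, N = 33, 32
        else:
            K, N = 33, 33
        a = torch.randn(16, K, device="cuda", dtype=base_dtype)
        b = torch.randn(K, N, device="cuda", dtype=base_dtype)
        ...  # rest of the existing test body unchanged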

torch.manual_seed(42)
input_dtype = torch.float8_e4m3fn
compare_type = torch.float32
40 changes: 34 additions & 6 deletions test/float8/test_compile.py
@@ -40,14 +40,25 @@ def _test_compile_base(
fullgraph: bool,
config: Float8LinearConfig,
dtype: torch.dtype,
pad_dimensions: bool,
Contributor:

same, I think we should test padding N, K, and N and K together

):
random.seed(0)
torch.manual_seed(0)
x_shape = (16, 16)

if pad_dimensions:
x_shape = (17, 17)
else:
x_shape = (16, 16)

linear_dtype = torch.bfloat16

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)

if pad_dimensions:
m_ref = nn.Linear(17, 35, bias=True, device="cuda", dtype=linear_dtype)
else:
m_ref = nn.Linear(16, 16, bias=True, device="cuda", dtype=linear_dtype)
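
Following up on the note above about covering N only, K only, and N and K together: one way the shape selection inside _test_compile_base could be generalized. The pad_case values and the concrete sizes below are hypothetical, not part of this PR:

    # Illustrative only: pick in_features (K) and out_features (N) so that exactly
    # the requested dimensions are misaligned and need padding to a multiple of 16.
    if pad_case == "pad_N_only":
        in_features, out_features = 16, 33
    elif pad_case == "pad_K_only":
        in_features, out_features = 17, 32
    elif pad_case == "pad_N_and_K":
        in_features, out_features = 17, 33
    else:
        in_features, out_features = 16, 32  # no padding required
    x = torch.randn(16, in_features, device="cuda", dtype=linear_dtype)
    m_ref = nn.Linear(in_features, out_features, bias=True, device="cuda", dtype=linear_dtype)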


m_fp8 = Float8Linear.from_float(
copy.deepcopy(m_ref),
@@ -71,6 +82,7 @@ def _get_config(
scaling_type_weight,
scaling_type_grad_output,
emulate,
pad_dimensions,
):
if scaling_type_input is ScalingType.STATIC:
cast_config_input = CastConfig(
@@ -99,11 +111,13 @@ def _get_config(
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
emulate=emulate,
pad_dimensions=pad_dimensions,
)
return config


@pytest.mark.parametrize("fullgraph", [True])
@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@@ -113,7 +127,9 @@ def _get_config(
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"pad_dimensions", [False, True]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_eager_only(
@@ -122,17 +138,19 @@ def test_eager_only(
scaling_type_input: ScalingType,
scaling_type_weight: ScalingType,
scaling_type_grad_output: ScalingType,
pad_dimensions: bool,
dtype: torch.dtype,
):
torch._dynamo.reset()
config = _get_config(
scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
)
_test_compile_base(
"eager",
fullgraph,
config,
dtype,
pad_dimensions,
)


@@ -147,6 +165,9 @@ def test_eager_only(
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@pytest.mark.parametrize(
"pad_dimensions", [False, True]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_aot_eager(
@@ -155,17 +176,19 @@ def test_aot_eager(
scaling_type_input: ScalingType,
scaling_type_weight: ScalingType,
scaling_type_grad_output: ScalingType,
pad_dimensions: bool,
dtype: torch.dtype,
):
torch._dynamo.reset()
config = _get_config(
scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
)
_test_compile_base(
"aot_eager",
fullgraph,
config,
dtype,
pad_dimensions,
)


@@ -180,6 +203,9 @@ def test_aot_eager(
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@pytest.mark.parametrize(
"pad_dimensions", [True, False]
)
@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_inductor(
@@ -188,17 +214,19 @@ def test_inductor(
scaling_type_input: ScalingType,
scaling_type_weight: ScalingType,
scaling_type_grad_output: ScalingType,
pad_dimensions: bool,
dtype: torch.dtype,
):
torch._dynamo.reset()
config = _get_config(
scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
)
_test_compile_base(
"inductor",
fullgraph,
config,
dtype,
pad_dimensions,
)


2 changes: 1 addition & 1 deletion torchao/float8/config.py
@@ -120,7 +120,7 @@ class Float8LinearConfig:
# inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
# _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16.
# This can cause a memory spike however so we keep this off by default.
pad_inner_dim: bool = False
pad_dimensions: bool = False
Contributor:

nit: update docblock
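
One possible wording for the updated docblock, based on the constraint already described in the surrounding comment (illustrative only, not part of this PR):

    # If True, then prior to performing the fp8 scaled matmul we pad the dimensions
    # of a and b with 0s so that the N and K dimensions seen by torch._scaled_mm are
    # multiples of 16, as it requires. Padding can cause a memory spike, so this is
    # kept off by default.
    pad_dimensions: bool = False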


# If True, emulation is used instead of hardware accelerated gemm
emulate: bool = False
6 changes: 3 additions & 3 deletions torchao/float8/float8_linear.py
@@ -141,21 +141,21 @@ def __init__(self, *args, **kwargs):
emulate,
self.config.gemm_config_output.use_fast_accum,
False,
self.config.pad_inner_dim,
self.config.pad_dimensions,
),
# grad_input
ScaledMMConfig(
emulate,
self.config.gemm_config_grad_input.use_fast_accum,
False,
self.config.pad_inner_dim,
self.config.pad_dimensions,
),
# grad_weight
ScaledMMConfig(
emulate,
self.config.gemm_config_grad_weight.use_fast_accum,
False,
self.config.pad_inner_dim,
self.config.pad_dimensions,
),
)

19 changes: 14 additions & 5 deletions torchao/float8/float8_ops.py
@@ -134,27 +134,34 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
a_scale = a._scale
b_data = b._data

out_shape = (a._data.size(0), b._data.size(1))

scaled_mm_config = choose_scaled_mm_config(
a._gemm_input_role,
a._linear_mm_config,
b._gemm_input_role,
b._linear_mm_config,
)

if scaled_mm_config.pad_inner_dim:
if scaled_mm_config.pad_dimensions:
assert a._data.size(1) == b._data.size(
0
), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
a_data = pad_tensor_for_matmul(a_data, dims=1)
b_data = pad_tensor_for_matmul(b_data, dims=0)
b_data = pad_tensor_for_matmul(b_data, dims=[0,1])

if not is_row_major(a_data.stride()):
a_data = a_data.contiguous()
if is_row_major(b_data.stride()):
b_data = b_data.t().contiguous().t()
b_scale = b._scale
return a_data, a_scale, b_data, b_scale

return a_data, a_scale, b_data, b_scale, out_shape

def postprocess_addmm(out: torch.Tensor, scaled_mm_config, out_shape):
Contributor:

nit: just inline instead of creating a new function? it's only two lines of code and used once

if scaled_mm_config.pad_dimensions:
out = out[:, :out_shape[1]]
return out
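
A sketch of what the inlining suggestion could look like at the call sites in float8_mm and float8_addmm, reusing the same slice that postprocess_addmm performs (illustrative only):

    # ... tensor_out = addmm_float8_unwrapped(...) as in the existing code, then:
    # slice away the columns introduced by padding N, if padding was enabled.
    if scaled_mm_config.pad_dimensions:
        tensor_out = tensor_out[:, : out_shape[1]]
    return tensor_out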

@implements([aten.mm.default, aten.matmul.default])
def float8_mm(aten_op, args, kwargs=None):
@@ -166,7 +173,7 @@ def float8_mm(aten_op, args, kwargs=None):
), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
type(a), type(b)
)
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
a_data, a_scale, b_data, b_scale, out_shape = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
scaled_mm_config = choose_scaled_mm_config(
a._gemm_input_role,
Expand All @@ -188,6 +195,7 @@ def float8_mm(aten_op, args, kwargs=None):
bias=None,
use_fast_accum=scaled_mm_config.use_fast_accum,
)
tensor_out = postprocess_addmm(out=tensor_out, scaled_mm_config=scaled_mm_config, out_shape=out_shape)
return tensor_out


@@ -201,7 +209,7 @@ def float8_addmm(aten_op, args, kwargs=None):
bias = args[0]
a = args[1]
b = args[2]
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
a_data, a_scale, b_data, b_scale, out_shape = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
assert bias.dtype == output_dtype, "bias dtype must match output dtype"
scaled_mm_config = choose_scaled_mm_config(
Expand All @@ -225,6 +233,7 @@ def float8_addmm(aten_op, args, kwargs=None):
bias=bias,
use_fast_accum=scaled_mm_config.use_fast_accum,
)
tensor_out = postprocess_addmm(out=tensor_out, scaled_mm_config=scaled_mm_config, out_shape=out_shape)
return tensor_out


4 changes: 2 additions & 2 deletions torchao/float8/float8_tensor.py
@@ -53,14 +53,14 @@ class ScaledMMConfig(NamedTuple):
emulate (bool): Whether to emulate the matmuls in fp32.
use_fast_accum (bool): Whether to use the fast-accumulation option for scaled_mm.
fp8_output (bool): Whether to output the result of the scaled_mm in fp8.
pad_inner_dim (bool): Whether to pad the inner dimension of a and b with 0s.
pad_dimensions (bool): Whether to pad the inner dimension of a and b with 0s.
This is needed for matmuls not aligned to 16.
"""

emulate: bool = False
use_fast_accum: bool = False
fp8_output: bool = False
pad_inner_dim: bool = False
pad_dimensions: bool = False


class LinearMMConfig(NamedTuple):
Expand Down
6 changes: 3 additions & 3 deletions torchao/float8/inference.py
@@ -22,13 +22,13 @@ class Float8MMConfig(NamedTuple):
Attributes:
emulate (bool): Whether to emulate the matmuls in fp32.
use_fast_accum (bool): Whether to use the fast-accumulation option for scaled_mm.
pad_inner_dim (bool): Whether to pad the inner dimension of a and b with 0s.
pad_dimensions (bool): Whether to pad the inner dimension of a and b with 0s.
This is needed for matmuls not aligned to 16.
"""

emulate: bool = False
use_fast_accum: bool = False
pad_inner_dim: bool = False
pad_dimensions: bool = False


def preprocess_data(
@@ -44,7 +44,7 @@ def preprocess_data(
Returns:
Preprocessed tensors A and B in the format for _scaled_mm.
"""
if scaled_mm_config.pad_inner_dim:
if scaled_mm_config.pad_dimensions:
assert a_data.size(1) == b_data.size(
0
), f"Inner dims must match for mm, got {a_data.size(1)} and {b_data.size(0)}"