Some changes in inner-padding option #858
base: main
@@ -40,14 +40,25 @@ def _test_compile_base(
     fullgraph: bool,
     config: Float8LinearConfig,
     dtype: torch.dtype,
+    pad_dimensions: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
-    x_shape = (16, 16)
+
+    if pad_dimensions:
+        x_shape = (17, 17)
+    else:
+        x_shape = (16, 16)
 
     linear_dtype = torch.bfloat16
 
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-    m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+
+    if pad_dimensions:
+        m_ref = nn.Linear(17, 35, bias=True, device="cuda", dtype=linear_dtype)
+    else:
+        m_ref = nn.Linear(16, 16, bias=True, device="cuda", dtype=linear_dtype)
+
+
     m_fp8 = Float8Linear.from_float(
         copy.deepcopy(m_ref),

Review comment on `pad_dimensions: bool,`: same, I think we should test padding N, K, and N and K together (see the sketch after this hunk).
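To make the reviewer's suggestion concrete, here is one possible shape of such a parametrization. This is a sketch only, not part of the PR: the `PAD_CASES` table, test name, and concrete shapes are assumptions; in the real test these shapes would be routed through `_get_config` and `_test_compile_base` with `pad_dimensions` enabled.

```python
# Sketch (not this PR's code): exercise N-only, K-only, and N-and-K padding.
import pytest
import torch
import torch.nn as nn

PAD_CASES = {
    # pad case: (x_shape, in_features (K), out_features (N))
    "pad_n":       ((16, 16), 16, 33),  # only N is not a multiple of 16
    "pad_k":       ((16, 17), 17, 32),  # only K is not a multiple of 16
    "pad_n_and_k": ((16, 17), 17, 33),  # both need padding
}

@pytest.mark.parametrize("pad_case", list(PAD_CASES))
def test_pad_case_shapes(pad_case):
    x_shape, in_features, out_features = PAD_CASES[pad_case]
    x = torch.randn(*x_shape, dtype=torch.bfloat16)
    m_ref = nn.Linear(in_features, out_features, bias=True, dtype=torch.bfloat16)
    # In the real test, x and m_ref would feed Float8Linear.from_float and
    # _test_compile_base with pad_dimensions enabled on the config.
    assert m_ref(x).shape == (x_shape[0], out_features)
```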
@@ -71,6 +82,7 @@ def _get_config(
     scaling_type_weight,
     scaling_type_grad_output,
     emulate,
+    pad_dimensions,
 ):
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input = CastConfig(

@@ -99,11 +111,13 @@ def _get_config(
         cast_config_weight=cast_config_weight,
         cast_config_grad_output=cast_config_grad_output,
         emulate=emulate,
+        pad_dimensions=pad_dimensions,
     )
     return config
 
 
 @pytest.mark.parametrize("fullgraph", [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
 @pytest.mark.parametrize(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )

@@ -113,7 +127,9 @@ def _get_config(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize(
+    "pad_dimensions", [False, True]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -122,17 +138,19 @@ def test_eager_only(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "eager",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
 
 

@@ -147,6 +165,9 @@ def test_eager_only(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_dimensions", [False, True]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
@@ -155,17 +176,19 @@ def test_aot_eager(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "aot_eager",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
 
 

@@ -180,6 +203,9 @@ def test_aot_eager(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_dimensions", [True, False]
+)
 @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
@@ -188,17 +214,19 @@ def test_inductor(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "inductor",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
 
 
@@ -120,7 +120,7 @@ class Float8LinearConfig:
     # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
     # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16.
     # This can cause a memory spike however so we keep this off by default.
-    pad_inner_dim: bool = False
+    pad_dimensions: bool = False
 
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False

Review comment on `pad_dimensions: bool = False`: nit: update docblock.
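For reference, the "multiple of 16" constraint the docblock mentions means each padded dimension is rounded up to the next multiple of 16, which is also where the memory overhead comes from. A minimal illustration follows; the helper name is made up for this sketch and is not the library's `pad_tensor_for_matmul`.

```python
# Illustration only: rounding a dimension up to the alignment _scaled_mm needs.
def round_up_to_multiple_of_16(dim: int) -> int:
    return ((dim + 15) // 16) * 16

# A (17, 35) weight would be padded out to (32, 48) before the matmul,
# which is the memory spike the docblock warns about.
assert round_up_to_multiple_of_16(17) == 32
assert round_up_to_multiple_of_16(35) == 48

# Enabling the behavior on the renamed field (assuming default construction works):
# config = Float8LinearConfig(pad_dimensions=True)
```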
@@ -134,27 +134,34 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data
 
+    out_shape = (a._data.size(0), b._data.size(1))
+
     scaled_mm_config = choose_scaled_mm_config(
         a._gemm_input_role,
         a._linear_mm_config,
         b._gemm_input_role,
         b._linear_mm_config,
     )
 
-    if scaled_mm_config.pad_inner_dim:
+    if scaled_mm_config.pad_dimensions:
         assert a._data.size(1) == b._data.size(
             0
         ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
-        b_data = pad_tensor_for_matmul(b_data, dims=0)
+        b_data = pad_tensor_for_matmul(b_data, dims=[0,1])
 
     if not is_row_major(a_data.stride()):
         a_data = a_data.contiguous()
     if is_row_major(b_data.stride()):
         b_data = b_data.t().contiguous().t()
     b_scale = b._scale
-    return a_data, a_scale, b_data, b_scale
+    return a_data, a_scale, b_data, b_scale, out_shape
+
+
+def postprocess_addmm(out: torch.Tensor, scaled_mm_config, out_shape):
+    if scaled_mm_config.pad_dimensions:
+        out = out[:, :out_shape[1]]
+    return out
 
 
 @implements([aten.mm.default, aten.matmul.default])
 def float8_mm(aten_op, args, kwargs=None):

Review comment on `postprocess_addmm`: nit: just inline instead of creating a new function? it's only two lines of code and used once (a sketch of the inlined form follows).
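What the reviewer's inlining suggestion could look like: the two lines of `postprocess_addmm` would simply follow the `torch._scaled_mm` call inside `float8_mm` and `float8_addmm`. Below is a standalone sketch of that slice; the helper name is invented only to keep the snippet runnable.

```python
import torch

def _slice_padded_output(tensor_out: torch.Tensor, out_shape, pad_dimensions: bool) -> torch.Tensor:
    # Same two lines as postprocess_addmm above; inlined, they would sit
    # directly after the _scaled_mm call instead of in a separate helper.
    if pad_dimensions:
        tensor_out = tensor_out[:, : out_shape[1]]
    return tensor_out

# e.g. a padded (16, 48) result sliced back to the logical (16, 35) output
out = _slice_padded_output(torch.zeros(16, 48), (16, 35), pad_dimensions=True)
assert out.shape == (16, 35)
```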
@@ -166,7 +173,7 @@ def float8_mm(aten_op, args, kwargs=None):
     ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
         type(a), type(b)
     )
-    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+    a_data, a_scale, b_data, b_scale, out_shape = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     scaled_mm_config = choose_scaled_mm_config(
         a._gemm_input_role,

@@ -188,6 +195,7 @@ def float8_mm(aten_op, args, kwargs=None):
         bias=None,
         use_fast_accum=scaled_mm_config.use_fast_accum,
     )
+    tensor_out = postprocess_addmm(out=tensor_out, scaled_mm_config=scaled_mm_config, out_shape=out_shape)
     return tensor_out
 
 

@@ -201,7 +209,7 @@ def float8_addmm(aten_op, args, kwargs=None):
     bias = args[0]
     a = args[1]
     b = args[2]
-    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+    a_data, a_scale, b_data, b_scale, out_shape = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     assert bias.dtype == output_dtype, "bias dtype must match output dtype"
     scaled_mm_config = choose_scaled_mm_config(

@@ -225,6 +233,7 @@ def float8_addmm(aten_op, args, kwargs=None):
         bias=bias,
         use_fast_accum=scaled_mm_config.use_fast_accum,
     )
+    tensor_out = postprocess_addmm(out=tensor_out, scaled_mm_config=scaled_mm_config, out_shape=out_shape)
     return tensor_out
 
 
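Why the output slice is needed at all: padding K only changes the reduction length, but padding N produces extra output columns that have to be cut back to the logical shape recorded in `out_shape`. A plain-torch illustration, not the fp8 path, with shapes chosen to match the 17/35 example used in the tests:

```python
import torch
import torch.nn.functional as F

a = torch.randn(16, 17)             # M=16, K=17
b = torch.randn(17, 35)             # K=17, N=35
out_shape = (a.size(0), b.size(1))  # logical (16, 35)

a_pad = F.pad(a, (0, 32 - 17))               # pad K up to 32
b_pad = F.pad(b, (0, 48 - 35, 0, 32 - 17))   # pad N up to 48 and K up to 32

out = a_pad @ b_pad                 # padded result is (16, 48)
out = out[:, : out_shape[1]]        # slice back to (16, 35)
assert torch.allclose(out, a @ b, atol=1e-4)
```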
Review comment: I think this would now need 3 cases: