
Commit 9ce26f8

y-sq authored and facebook-github-bot committed
Test compile with inner-padding (#858)
Summary:
Pull Request resolved: #858

Add test cases to verify that compiling with inner-dim padding works with the Triton PR triton-lang/triton#4222. Before that PR, the Inductor-generated kernel failed at:

```
tmp10 = tl.where(tmp6, tmp8, tmp9)
TypeError: unexpected type fp8e5 and fp8e5
```

Reviewed By: irobert0126

Differential Revision: D62003827
1 parent e283743 · commit 9ce26f8
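
For context, here is a minimal sketch of the path the new `pad_inner_dim=True` cases exercise, assuming the torchao float8 training APIs (`Float8LinearConfig`, `convert_to_float8_training`). Only `pad_inner_dim` and the (17, 16) input shape come from this diff; the layer sizes and the conversion/compile calls are illustrative, not the test code itself.

```
# Hedged sketch, not the test code: compile a float8 linear whose input has a
# dimension (17) that is not a multiple of 16, so inner-dim padding is needed.
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training  # assumed import path

# pad_inner_dim pads the scaled-matmul inner dimension; on GPUs without native
# float8 support the tests instead set emulate=True (see the parametrization below).
config = Float8LinearConfig(pad_inner_dim=True)

m = nn.Sequential(nn.Linear(16, 32, device="cuda", dtype=torch.bfloat16))  # sizes are illustrative
m = convert_to_float8_training(m, config=config)
m = torch.compile(m, backend="inductor", fullgraph=True)

x = torch.randn(17, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()  # before triton-lang/triton#4222 this could hit the tl.where fp8e5 TypeError
```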

1 file changed (+23, -3)


test/float8/test_compile.py

Lines changed: 23 additions & 3 deletions
```
@@ -40,10 +40,13 @@ def _test_compile_base(
     fullgraph: bool,
     config: Float8LinearConfig,
     dtype: torch.dtype,
+    pad_inner_dim: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
     x_shape = (16, 16)
+    if pad_inner_dim:
+        x_shape = (17, 16)
     linear_dtype = torch.bfloat16
 
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
@@ -70,6 +73,7 @@ def _get_config(
     scaling_type_input,
     scaling_type_weight,
     scaling_type_grad_output,
+    pad_inner_dim,
     emulate,
 ):
     if scaling_type_input is ScalingType.STATIC:
@@ -99,6 +103,7 @@ def _get_config(
         cast_config_weight=cast_config_weight,
         cast_config_grad_output=cast_config_grad_output,
         emulate=emulate,
+        pad_inner_dim=pad_inner_dim,
     )
     return config
 
@@ -114,6 +119,9 @@ def _get_config(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
 @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize(
+    "pad_inner_dim", [True, False]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -122,17 +130,19 @@ def test_eager_only(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_inner_dim,
     )
     _test_compile_base(
         "eager",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )
 
 
@@ -147,6 +157,9 @@ def test_eager_only(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_inner_dim", [True, False]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
@@ -155,17 +168,19 @@ def test_aot_eager(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_inner_dim,
     )
     _test_compile_base(
         "aot_eager",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )
 
 
@@ -180,6 +195,9 @@ def test_aot_eager(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_inner_dim", [False, True]
+)
 @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
@@ -188,17 +206,19 @@ def test_inductor(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_inner_dim,
     )
     _test_compile_base(
         "inductor",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )
 
 
```
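To run only the compile tests touched by this change, something like the following should work, assuming the repository's standard pytest setup (all `pad_inner_dim` values are covered automatically by the parametrization):

```
pytest test/float8/test_compile.py -k "test_eager_only or test_aot_eager or test_inductor"
```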