
Commit 0fdfe3e

y-sq authored and facebook-github-bot committed

Test compile with inner-padding

Summary: Add test cases to verify that compiling with inner-padding works with the triton PR triton-lang/triton#4222. Before that PR, the inductor code-generated kernel fails at

```
tmp10 = tl.where(tmp6, tmp8, tmp9)
TypeError: unexpected type fp8e5 and fp8e5
```

Reviewed By: irobert0126

Differential Revision: D62003827
1 parent 9d169a8 commit 0fdfe3e
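
For context, a minimal standalone sketch (not code from this commit) of the scenario the new test parameter exercises. It assumes torchao's `Float8LinearConfig` exposes a `pad_inner_dim` flag and that `convert_to_float8_training` is the module-conversion entry point; exact names may differ at this revision. An input shape that is not a multiple of 16 forces the padded float8 matmul path under `torch.compile`:

```python
# Hedged sketch of the pad_inner_dim compile path; API names are assumed
# from torchao's float8 module and may differ in this repo's revision.
import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Pad matmul dimensions that are not multiples of 16 (assumed flag,
# mirrors the test's pad_inner_dim parameter).
config = Float8LinearConfig(pad_inner_dim=True)

# bfloat16 linear on CUDA, swapped to float8 training linears.
m = nn.Sequential(nn.Linear(16, 32)).to(device="cuda", dtype=torch.bfloat16)
m = convert_to_float8_training(m, config=config)

m = torch.compile(m, backend="inductor", fullgraph=True)

# 17 rows mirrors the test's (17, 16) input: the backward grad_weight matmul
# then has an inner dim of 17, which the fp8 mm generally requires to be
# padded up to a multiple of 16.
x = torch.randn(17, 16, device="cuda", dtype=torch.bfloat16)
y = m(x)
y.sum().backward()
```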


test/float8/test_compile.py

Lines changed: 21 additions & 0 deletions
@@ -40,10 +40,13 @@ def _test_compile_base(
     fullgraph: bool,
     config: Float8LinearConfig,
     dtype: torch.dtype,
+    pad_inner_dim: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
     x_shape = (16, 16)
+    if pad_inner_dim:
+        x_shape = (17, 16)
     linear_dtype = torch.bfloat16

     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
@@ -114,6 +117,9 @@ def _get_config(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
 @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize(
+    "pad_inner_dim", [True, False]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -122,17 +128,20 @@ def test_eager_only(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
         scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        pad_inner_dim=pad_inner_dim,
     )
     _test_compile_base(
         "eager",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )


@@ -147,6 +156,9 @@ def test_eager_only(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_inner_dim", [True, False]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
@@ -155,17 +167,20 @@ def test_aot_eager(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
         scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        pad_inner_dim=pad_inner_dim,
     )
     _test_compile_base(
         "aot_eager",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )


@@ -180,6 +195,9 @@ def test_aot_eager(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_inner_dim", [False, True]
+)
 @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
@@ -188,17 +206,20 @@ def test_inductor(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
         scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        pad_inner_dim=pad_inner_dim,
     )
     _test_compile_base(
         "inductor",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )