@@ -40,10 +40,13 @@ def _test_compile_base(
     fullgraph: bool,
     config: Float8LinearConfig,
     dtype: torch.dtype,
+    pad_inner_dim: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
     x_shape = (16, 16)
+    if pad_inner_dim:
+        x_shape = (17, 16)
     linear_dtype = torch.bfloat16

     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
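Note on the (17, 16) shape: when pad_inner_dim is enabled, the test uses an input whose leading dim of 17 is not a multiple of 16, so the backward gemms exercise the padding path (torch._scaled_mm generally expects the contraction dim to be a multiple of 16). Below is a minimal sketch of what such padding looks like; the helper name _pad_inner_dim_to_multiple_of_16 is hypothetical and not part of this PR.

```python
import torch
import torch.nn.functional as F

def _pad_inner_dim_to_multiple_of_16(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not from this PR): pad the last (inner/K) dim on
    # the right so its size becomes the next multiple of 16.
    pad_amount = (-t.shape[-1]) % 16
    if pad_amount == 0:
        return t
    return F.pad(t, (0, pad_amount))

# Example: with x_shape == (17, 16), the grad_weight gemm contracts over the
# batch dim of 17, which is where padding up to 32 would kick in.
x = torch.randn(17, 16)
padded = _pad_inner_dim_to_multiple_of_16(x.t())  # (16, 17) -> (16, 32)
assert padded.shape == (16, 32)
```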
@@ -114,6 +117,9 @@ def _get_config(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
 @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize(
+    "pad_inner_dim", [True, False]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -122,17 +128,20 @@ def test_eager_only(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
         scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        pad_inner_dim=pad_inner_dim,
     )
     _test_compile_base(
         "eager",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )


@@ -147,6 +156,9 @@ def test_eager_only(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_inner_dim", [True, False]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
@@ -155,17 +167,20 @@ def test_aot_eager(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
         scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        pad_inner_dim=pad_inner_dim,
     )
     _test_compile_base(
         "aot_eager",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )


@@ -180,6 +195,9 @@ def test_aot_eager(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_inner_dim", [False, True]
+)
 @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
@@ -188,17 +206,20 @@ def test_inductor(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
         scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        pad_inner_dim=pad_inner_dim,
    )
     _test_compile_base(
         "inductor",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )


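For context on where the new flag ends up, here is a hedged sketch of how _get_config might thread pad_inner_dim into Float8LinearConfig. The PR only shows the call sites, so the import path and the CastConfig wiring below are assumptions modeled on torchao's float8 config API; the intent of pad_inner_dim is to pad the fp8 gemm operands so their inner dims are multiples of 16 before torch._scaled_mm is called.

```python
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType

def _get_config_sketch(
    scaling_type_input: ScalingType,
    scaling_type_weight: ScalingType,
    scaling_type_grad_output: ScalingType,
    emulate: bool,
    pad_inner_dim: bool = False,
) -> Float8LinearConfig:
    # Illustrative only: static scaling would additionally require a
    # static_scale on the corresponding CastConfig.
    return Float8LinearConfig(
        cast_config_input=CastConfig(scaling_type=scaling_type_input),
        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
        emulate=emulate,
        # When True, gemm operands are padded so their inner (K) dims are
        # multiples of 16 before the scaled matmul runs.
        pad_inner_dim=pad_inner_dim,
    )
```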