@@ -40,10 +40,13 @@ def _test_compile_base(
     fullgraph: bool,
     config: Float8LinearConfig,
     dtype: torch.dtype,
+    pad_inner_dim: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
     x_shape = (16, 16)
+    if pad_inner_dim:
+        x_shape = (17, 16)
     linear_dtype = torch.bfloat16
 
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
@@ -70,6 +73,7 @@ def _get_config(
     scaling_type_input,
     scaling_type_weight,
     scaling_type_grad_output,
     emulate,
+    pad_inner_dim,
 ):
     if scaling_type_input is ScalingType.STATIC:
@@ -99,6 +103,7 @@ def _get_config(
         cast_config_weight=cast_config_weight,
         cast_config_grad_output=cast_config_grad_output,
         emulate=emulate,
+        pad_inner_dim=pad_inner_dim,
     )
     return config
 
@@ -114,6 +119,9 @@ def _get_config(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
 @pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize(
+    "pad_inner_dim", [True, False]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -122,17 +130,19 @@ def test_eager_only(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_inner_dim,
     )
     _test_compile_base(
         "eager",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )
 
 
@@ -147,6 +157,9 @@ def test_eager_only(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_inner_dim", [True, False]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
@@ -155,17 +168,19 @@ def test_aot_eager(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_inner_dim,
     )
     _test_compile_base(
         "aot_eager",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )
 
 
@@ -180,6 +195,9 @@ def test_aot_eager(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_inner_dim", [False, True]
+)
 @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
@@ -188,17 +206,19 @@ def test_inductor(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_inner_dim: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_inner_dim,
     )
     _test_compile_base(
         "inductor",
         fullgraph,
         config,
         dtype,
+        pad_inner_dim,
     )
 
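For context on what the new parametrization exercises: pad_inner_dim=True on Float8LinearConfig asks the float8 matmul path to pad the contraction ("inner") dimension of its matmuls up to a multiple of 16, which the CUDA float8 kernels behind torch._scaled_mm require, and the (17, 16) activation shape makes the backward-pass matmuls contract over a dimension of 17. The snippet below is only a minimal sketch of that rounding-up step, assuming the multiple-of-16 requirement; pad_to_multiple_of_16 is an illustrative helper, not torchao's implementation.

import torch

# Illustrative helper (not torchao's API): zero-pad one dimension of a tensor
# up to the next multiple of 16, mirroring what pad_inner_dim is meant to do
# for the contraction dimension of float8 matmuls.
def pad_to_multiple_of_16(t: torch.Tensor, dim: int) -> torch.Tensor:
    pad = (-t.shape[dim]) % 16
    if pad == 0:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = pad
    return torch.cat([t, t.new_zeros(pad_shape)], dim=dim)

x = torch.randn(17, 16)                   # the shape used by the padded test path
print(pad_to_multiple_of_16(x, 0).shape)  # torch.Size([32, 16])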