@@ -171,7 +171,7 @@ def test_rms_norm(self):
         x = torch.randn(16, 64)
         out = rms_norm(x)
         self.assertEqual(out.shape, (16, 64))
-
+
         # Test with different eps
         rms_norm = RMSNorm(dim=64, eps=1e-5)
         out = rms_norm(x)
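The `RMSNorm` class exercised by this hunk is defined in the benchmark utilities and is not part of this diff. As a rough reference for what the shape and `eps` assertions above check, a minimal RMSNorm sketch (an assumption for illustration, not the project's implementation) could look like this:

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    # Minimal sketch: divide the input by the root-mean-square of its last
    # dimension, then apply a learned per-feature gain. The output keeps the
    # input shape, which is what the test above asserts.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)


x = torch.randn(16, 64)
out = RMSNorm(dim=64, eps=1e-5)(x)
assert out.shape == (16, 64)
```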
@@ -184,38 +184,50 @@ def test_rms_norm_linear_activation(self):
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
         self.assertEqual(out.dtype, torch.float32)
-
+
         # Test with ReLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
+        model = RMSNormLinearActivation(
+            fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu"
+        )
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
         self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
-
+
         # Test with SiLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
+        model = RMSNormLinearActivation(
+            fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu"
+        )
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
-
+
         # Test with invalid activation
         with self.assertRaises(ValueError):
-            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
+            RMSNormLinearActivation(
+                fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid"
+            )

     def test_transformer_block(self):
         # Test with default parameters
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32
+        )
         x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
         out = model(x)
         self.assertEqual(out.shape, (16, 16, 64))
         self.assertEqual(out.dtype, torch.float32)
-
+
         # Test with different parameters
-        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32
+        )
         x = torch.randn(8, 32, 128)
         out = model(x)
         self.assertEqual(out.shape, (8, 32, 128))
-
+
         # Test with different head dimensions
-        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32
+        )
         x = torch.randn(4, 8, 96)
         out = model(x)
         self.assertEqual(out.shape, (4, 8, 96))
@@ -255,7 +267,7 @@ def test_create_model_and_input(self):
         )
         self.assertIsInstance(model, RMSNormLinearActivation)
         self.assertEqual(input_data.shape, (m, k))
-
+
         # Test TransformerBlock
         model, input_data = create_model_and_input(
             model_type="transformer_block",
@@ -266,40 +278,50 @@ def test_create_model_and_input(self):
             device="cpu",
         )
         self.assertIsInstance(model, TransformerBlock)
-        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
+        self.assertEqual(
+            input_data.shape, (m, 16, k)
+        )  # [batch_size, seq_len, hidden_dim]

     def test_quantization_on_models(self):
         # Test quantization on RMSNormLinearActivation
         model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
         x = torch.randn(16, 64)
-
+
         # Test with Int8WeightOnlyConfig
         config = string_to_config(quantization="int8wo", sparsity=None)
         if config is not None:
             # Skip quantization test if torchao.quantization.quantize is not available
             try:
                 from torchao.quantization import quantize
+
                 quantized_model = quantize(model, config)
                 out = quantized_model(x)
                 self.assertEqual(out.shape, (16, 32))
             except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
-
+                print(
+                    "Skipping quantization test: torchao.quantization.quantize not available"
+                )
+
         # Test quantization on TransformerBlock
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32
+        )
         x = torch.randn(16, 16, 64)
-
+
         # Test with Int8WeightOnlyConfig
         config = string_to_config(quantization="int8wo", sparsity=None)
         if config is not None:
             # Skip quantization test if torchao.quantization.quantize is not available
             try:
                 from torchao.quantization import quantize
+
                 quantized_model = quantize(model, config)
                 out = quantized_model(x)
                 self.assertEqual(out.shape, (16, 16, 64))
             except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
+                print(
+                    "Skipping quantization test: torchao.quantization.quantize not available"
+                )

     def test_generate_results_csv(self):
         results = [
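As a usage note, the guarded quantization flow above can be read outside the `unittest` harness roughly as in the sketch below. `RMSNormLinearActivation` and `string_to_config` are the benchmark helpers this test file already imports (the sketch assumes they are on the path; the import lines are not shown in this diff), and the `ImportError` guard mirrors the tests' own tolerance for `torchao.quantization.quantize` being absent from a given torchao install.

```python
import torch

# Assumes RMSNormLinearActivation and string_to_config are importable from the
# same benchmark utilities this test file uses (imports omitted in this diff).
model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
x = torch.randn(16, 64)

config = string_to_config(quantization="int8wo", sparsity=None)  # Int8WeightOnlyConfig
if config is not None:
    try:
        from torchao.quantization import quantize
    except ImportError:
        # Same fallback as the tests: skip quantization when the entry point is missing.
        print("Skipping quantization: torchao.quantization.quantize not available")
    else:
        quantized_model = quantize(model, config)
        out = quantized_model(x)
        assert out.shape == (16, 32)
```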