@@ -104,8 +104,8 @@ def test_json_serializable(self):
         """
         quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
         d = quantization_config.to_dict()
-        self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
-        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
+        self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
+        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
         quantization_config.to_json_string(use_diff=False)

@@ -159,7 +159,7 @@ def test_int4wo_quant_bfloat16_conversion(self):
         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=None,
+            torch_dtype=torch.bfloat16,
             device_map=self.device,
             quantization_config=quant_config,
         )
@@ -282,7 +282,7 @@ def test_autoquant(self):

         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
+            torch_dtype="auto",
             device_map=self.device,
             quantization_config=quant_config,
         )
@@ -295,7 +295,7 @@ def test_autoquant(self):

         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)

-        EXPECTED_OUTPUT = ' What are we having for dinner?\n\n10. "Dinner is ready'
+        EXPECTED_OUTPUT = " What are we having for dinner?\n\nJane: (sighs)"
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )
@@ -307,9 +307,7 @@ def test_autoquant(self):
 class TorchAoSerializationTest(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 10
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
-    # TODO: investigate why we don't have the same output as the original model for this test
-    SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
@@ -326,9 +324,10 @@ def setUpClass(cls):

     def setUp(self):
         self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
+        torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         self.quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch_dtype,
             device_map=self.device,
             quantization_config=self.quant_config,
         )
@@ -342,50 +341,49 @@ def test_original_model_expected_output(self):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)

-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     def check_serialization_expected_output(self, device, expected_output):
         """
         Test if we can serialize and load/infer the model again on the same device
         """
+        torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
-                self.model_name, torch_dtype=torch.bfloat16, device_map=device
+                tmpdirname, torch_dtype=torch_dtype, device_map=device
             )
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)

             output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
             self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)

     def test_serialization_expected_output(self):
-        self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)


 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"

     @require_torch_gpu
     def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)


 class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"

     @require_torch_gpu
     def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)


 @require_torch_gpu
@@ -397,53 +395,55 @@ class TorchAoSerializationGPTTest(TorchAoSerializationTest):
 @require_torch_gpu
 class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"


 @require_torch_gpu
 class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"


 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"

-    def setUp(self):
+    # called only once for all tests in this class
+    @classmethod
+    def setUpClass(cls):
         if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")

         from torchao.quantization import Float8WeightOnlyConfig

-        self.quant_scheme = Float8WeightOnlyConfig()
-        self.quant_scheme_kwargs = {}
-        super().setUp()
+        cls.quant_scheme = Float8WeightOnlyConfig()
+        cls.quant_scheme_kwargs = {}
+
+        super().setUpClass()


 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"

-    def setUp(self):
+    # called only once for all tests in this class
+    @classmethod
+    def setUpClass(cls):
         if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")

         from torchao.quantization import Int8DynamicActivationInt4WeightConfig

-        self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
-        self.quant_scheme_kwargs = {}
-        super().setUp()
+        cls.quant_scheme = Int8DynamicActivationInt4WeightConfig()
+        cls.quant_scheme_kwargs = {}
+
+        super().setUpClass()


 if __name__ == "__main__":