import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+ from torchao.dtypes import (
+     TensorCoreTiledLayoutType,
+ )
+ from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+     AffineFakeQuantizedTensor,
+ )
from torchao.quantization.prototype.qat.utils import (
    _choose_qparams_per_token_asymmetric,
    _fake_quantize_per_channel_group,
    _fake_quantize_per_token,
    _GenericFakeQuantize,
+     _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
+ )
+ from torchao.quantization.quant_api import (
+     int4_weight_only,
+     quantize_,
)
from torchao.quantization.quant_primitives import (
    fake_quantize_affine,
@@ -190,6 +201,7 @@ def test_qat_8da4w_linear(self):
        ptq_out = ptq_linear(x2)
        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

+     # TODO: compare against quantize_ API instead
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_quantizer(self):
        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
@@ -217,13 +229,6 @@ def test_qat_8da4w_quantizer(self):
        converted_out = converted_model(*x)
        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

-         # Compare converted state dict
-         ptq_state_dict = ptq_model.state_dict()
-         converted_state_dict = converted_model.state_dict()
-         self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
-         for k in ptq_state_dict.keys():
-             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
-
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_quantizer_meta_weights(self):
        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
@@ -236,6 +241,20 @@ def test_qat_8da4w_quantizer_meta_weights(self):
        qat_model = qat_quantizer.prepare(m)
        self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

+     def _copy_subclass_weights(
+         self,
+         nn_linear: torch.nn.Linear,
+         subclass_linear: AffineFakeQuantizedTensor,
+     ):
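+         # The QAT linear stores its weight as an AffineFakeQuantizedTensor; copy the
+         # wrapped original_tensor back into the plain nn.Linear as a regular Parameter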
+         nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)
+
+     def _assert_matches_subclass_weights(
+         self,
+         nn_linear: torch.nn.Linear,
+         subclass_linear: AffineFakeQuantizedTensor,
+     ):
+         torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)
+
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_quantizer_disable_fake_quant(self):
        """
@@ -247,6 +266,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
            enable_8da4w_fake_quant,
        )

+         def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
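+             # Check the enabled flag on the AffineFakeQuantizedTensor weight, and that the
+             # stored subclass input pre-hook handle is non-None exactly when fake quant is enabled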
+             self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
+             self.assertEqual(m.weight.fake_quant_enabled, enabled)
+             self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
+             (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
+             if enabled:
+                 self.assertIsNotNone(handle)
+             else:
+                 self.assertIsNone(handle)
+
        group_size = 16
        torch.manual_seed(self.SEED)
        m = M()
@@ -255,14 +284,14 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        qat_model = quantizer.prepare(m)
        qat_model.apply(disable_8da4w_fake_quant)
-         self.assertFalse(qat_model.linear1._fake_quant_enabled)
-         self.assertFalse(qat_model.linear2._fake_quant_enabled)
-         self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
+         assert_fake_quant_enabled(qat_model.linear1, enabled=False)
+         assert_fake_quant_enabled(qat_model.linear2, enabled=False)
+         assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)

        # Disabled fake quant is just a normal linear
-         m2.linear1.weight = qat_model.linear1.weight
-         m2.linear2.weight = qat_model.linear2.weight
-         m2.sub.linear.weight = qat_model.sub.linear.weight
+         self._copy_subclass_weights(m2.linear1, qat_model.linear1)
+         self._copy_subclass_weights(m2.linear2, qat_model.linear2)
+         self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
@@ -272,16 +301,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):

        # Renable fake quant
        qat_model.apply(enable_8da4w_fake_quant)
-         self.assertTrue(qat_model.linear1._fake_quant_enabled)
-         self.assertTrue(qat_model.linear2._fake_quant_enabled)
-         self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
+         assert_fake_quant_enabled(qat_model.linear1, enabled=True)
+         assert_fake_quant_enabled(qat_model.linear2, enabled=True)
+         assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)

        # Fake quant should be applied as normal
        quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        qat_model2 = quantizer2.prepare(m3)
-         qat_model2.linear1.weight = qat_model.linear1.weight
-         qat_model2.linear2.weight = qat_model.linear2.weight
-         qat_model2.sub.linear.weight = qat_model.sub.linear.weight
+         qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
+         qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
+         qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
@@ -306,9 +335,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        qat_model = quantizer.prepare(m)
        qat_model.apply(disable_8da4w_fake_quant)
-         nn_model.linear1.weight = qat_model.linear1.weight
-         nn_model.linear2.weight = qat_model.linear2.weight
-         nn_model.sub.linear.weight = qat_model.sub.linear.weight
+         self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
+         self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
+         self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

        # Simulate training for both models
        optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -330,9 +359,55 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
        optimizer2.step()

        # After 1 training step, weights should match exactly
-         torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
-         torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
-         torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
+         self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
+         self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
+         self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+
+     def _test_qat_quantized_gradients(self, quantizer):
+         """
+         Test that QAT produces gradients in the backward pass.
+         """
+         num_steps = 10
+         torch.manual_seed(self.SEED)
+         m = M()
+         model = quantizer.prepare(m)
+         optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+         loss_fn = torch.nn.CrossEntropyLoss()
+
+         # Simulate training
+         current_step = 0
+         last_linear1_grad = None
+         last_linear2_grad = None
+         last_sub_linear_grad = None
+         while current_step < num_steps:
+             example_inputs = model.example_inputs()
+             target = torch.randn(1, 512).float()
+             output = model(*example_inputs)
+             loss = loss_fn(output, target)
+             loss.backward()
+             # assert each linear grad is updated
+             new_linear1_grad = model.linear1.weight.grad
+             new_linear2_grad = model.linear2.weight.grad
+             new_sub_linear_grad = model.sub.linear.weight.grad
+             self.assertIsNotNone(new_linear1_grad)
+             self.assertIsNotNone(new_linear2_grad)
+             self.assertIsNotNone(new_sub_linear_grad)
+             if current_step > 0:
+                 self.assertFalse(torch.equal(last_linear1_grad, new_linear1_grad))
+                 self.assertFalse(torch.equal(last_linear2_grad, new_linear2_grad))
+                 self.assertFalse(torch.equal(last_sub_linear_grad, new_sub_linear_grad))
+             last_linear1_grad = new_linear1_grad
+             last_linear2_grad = new_linear2_grad
+             last_sub_linear_grad = new_sub_linear_grad
+             optimizer.zero_grad()
+             optimizer.step()
+             current_step += 1
+
+     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+     def test_qat_8da4w_quantizer_gradients(self):
+         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
+         self._test_qat_quantized_gradients(quantizer)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_generic_fake_quantize(self):
@@ -353,7 +428,7 @@ def test_qat_generic_fake_quantize(self):
        block_size = (1, ao_input.shape[-1])
        ao_s = copy.deepcopy(py_s)
        ao_zp = copy.deepcopy(py_zp)
-         ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax, block_size)
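+         # block_size is now passed right after the input tensor in _GenericFakeQuantize.apply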
+         ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax)
        ao_out.sum().backward()

        torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
@@ -373,10 +448,7 @@ def _assert_close_4w(self, val, ref):
        print(mean_err)
        self.assertTrue(mean_err < 0.05)

-     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-     # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-     @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
    def test_qat_4w_primitives(self):
        n_bit = 4
        group_size = 32
@@ -464,11 +536,9 @@ def test_qat_4w_quantizer(self):
        qat_quantizer = Int4WeightOnlyQATQuantizer(
            groupsize=group_size, inner_k_tiles=inner_k_tiles,
        )
-         ptq_quantizer = Int4WeightOnlyQuantizer(
-             groupsize=group_size, inner_k_tiles=inner_k_tiles,
-         )
        qat_model = qat_quantizer.prepare(m)
-         ptq_model = ptq_quantizer.quantize(m2)
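+         # Quantize the model copy in place with the quantize_ API to serve as the PTQ baseline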
+         ptq_model = m2
+         quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))

        # Compare model values
        torch.manual_seed(self.SEED)
@@ -483,12 +553,11 @@ def test_qat_4w_quantizer(self):
        converted_out = converted_model(*x)
        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

-         # Compare converted state dict
-         ptq_state_dict = ptq_model.state_dict()
-         converted_state_dict = converted_model.state_dict()
-         self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
-         for k in ptq_state_dict.keys():
-             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+     def test_qat_4w_quantizer_gradients(self):
+         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
+         self._test_qat_quantized_gradients(quantizer)


if __name__ == "__main__":