Commit 2c8e3f3

Refactor QAT to use tensor subclasses (#585)
This commit refactors QAT to use tensor subclasses. This is motivated by the general move towards tensor subclasses in torchao for better composability with other subclasses like DTensors.

To achieve this, we introduce `AffineFakeQuantizedTensor`, which is analogous to `AffineQuantizedTensor` but applies fake quantization instead and requires gradient updates. `AffineFakeQuantizedTensor` wraps the original weight or input activation tensor and applies fake quantization dynamically only when the linear function is called. Gradients only flow to the outer tensor (`AffineFakeQuantizedTensor`) and never to the inner tensor. For weights, the outer tensor is also a `torch.nn.Parameter`, and gradient updates received by the outer tensor are then passed to the inner tensor through ops like `aten.add_` and `aten.mul_`.

An important difference between the PTQ and QAT flows is how the input activation subclasses are inserted. For QAT, we use an `nn.Module` `forward_pre_hook` instead of relying on a separate subclass, `LinearActivationQuantizedTensor`, that wraps the weight subclass. The problem with the old PTQ approach is that it can create subclasses under `__torch_dispatch__`, which runs below autograd; subclasses created there cannot carry gradients, which made it difficult to get gradients to flow correctly. It is also not very intuitive that quantizing the input activation has to go through the weights. In the new approach used by QAT, we instead register a `forward_pre_hook` that wraps the input activations before each call to forward (a minimal sketch follows the checklist below). This approach is also motivated by how [DTensor wraps their subclasses](https://github.com/pytorch/pytorch/blob/844103197d3e8cf6b4b59176e473365113f4f962/torch/distributed/tensor/parallel/style.py#L521).

- [x] Add AffineFakeQuantizedTensor
- [x] Add support for int4 weight only fake quantize
- [x] Add support for int8 dynamic activations + int4 weight fake quantize (8da4w)
- [x] Add prepare and convert path to int4 QAT quantizer
- [x] Add prepare and convert path to 8da4w QAT quantizer
- [x] Support enabling and disabling fake quant dynamically
- [x] Support `__repr__` in AffineFakeQuantizedTensor
- [x] Fix backward pass for int4 weight only
- [x] Fix backward pass for int8 dynamic activations + int4 weight
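To make the `forward_pre_hook` approach concrete, here is a minimal sketch of the idea only, not the code added in this commit: a hook fake-quantizes the input activation just before each `forward` call, with a hypothetical straight-through helper (`_wrap_input_activation`) standing in for `AffineFakeQuantizedTensor`; the helper name and quantization parameters below are illustrative.

```python
import torch

# Hypothetical stand-in for wrapping an input activation; the real
# AffineFakeQuantizedTensor carries affine qparams and applies fake
# quantization lazily when the linear op runs.
def _wrap_input_activation(x: torch.Tensor) -> torch.Tensor:
    scale = x.abs().amax() / 127 + 1e-12
    x_fq = torch.clamp(torch.round(x / scale), -128, 127) * scale
    # Straight-through estimator: fake-quantized values in the forward pass,
    # identity gradient back to x in the backward pass.
    return x + (x_fq - x).detach()

def _input_activation_prehook(module, args):
    # forward_pre_hook signature: (module, args) -> new args (or None).
    return (_wrap_input_activation(args[0]),) + args[1:]

linear = torch.nn.Linear(16, 8)
handle = linear.register_forward_pre_hook(_input_activation_prehook)

x = torch.randn(2, 16, requires_grad=True)
out = linear(x)       # input is fake-quantized before F.linear is called
out.sum().backward()  # gradients still flow to x and to linear.weight
handle.remove()       # removing the hook disables input fake quantization in this sketch
```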
1 parent 1909171 commit 2c8e3f3

File tree

6 files changed: +711 lines added, -142 lines removed


test/quantization/test_qat.py

Lines changed: 108 additions & 39 deletions
@@ -12,11 +12,22 @@
 
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
+from torchao.dtypes import (
+    TensorCoreTiledLayoutType,
+)
+from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+    AffineFakeQuantizedTensor,
+)
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _GenericFakeQuantize,
+    _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
+)
+from torchao.quantization.quant_api import (
+    int4_weight_only,
+    quantize_,
 )
 from torchao.quantization.quant_primitives import (
     fake_quantize_affine,
@@ -190,6 +201,7 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
+    # TODO: compare against quantize_ API instead
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
@@ -217,13 +229,6 @@ def test_qat_8da4w_quantizer(self):
         converted_out = converted_model(*x)
         torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
 
-        # Compare converted state dict
-        ptq_state_dict = ptq_model.state_dict()
-        converted_state_dict = converted_model.state_dict()
-        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
-        for k in ptq_state_dict.keys():
-            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
@@ -236,6 +241,20 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
 
+    def _copy_subclass_weights(
+        self,
+        nn_linear: torch.nn.Linear,
+        subclass_linear: AffineFakeQuantizedTensor,
+    ):
+        nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)
+
+    def _assert_matches_subclass_weights(
+        self,
+        nn_linear: torch.nn.Linear,
+        subclass_linear: AffineFakeQuantizedTensor,
+    ):
+        torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)
+
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
@@ -247,6 +266,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
             enable_8da4w_fake_quant,
         )
 
+        def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
+            self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
+            self.assertEqual(m.weight.fake_quant_enabled, enabled)
+            self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
+            (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
+            if enabled:
+                self.assertIsNotNone(handle)
+            else:
+                self.assertIsNone(handle)
+
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
@@ -255,14 +284,14 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        self.assertFalse(qat_model.linear1._fake_quant_enabled)
-        self.assertFalse(qat_model.linear2._fake_quant_enabled)
-        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
+        assert_fake_quant_enabled(qat_model.linear1, enabled=False)
+        assert_fake_quant_enabled(qat_model.linear2, enabled=False)
+        assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)
 
         # Disabled fake quant is just a normal linear
-        m2.linear1.weight = qat_model.linear1.weight
-        m2.linear2.weight = qat_model.linear2.weight
-        m2.sub.linear.weight = qat_model.sub.linear.weight
+        self._copy_subclass_weights(m2.linear1, qat_model.linear1)
+        self._copy_subclass_weights(m2.linear2, qat_model.linear2)
+        self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -272,16 +301,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
 
         # Renable fake quant
         qat_model.apply(enable_8da4w_fake_quant)
-        self.assertTrue(qat_model.linear1._fake_quant_enabled)
-        self.assertTrue(qat_model.linear2._fake_quant_enabled)
-        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
+        assert_fake_quant_enabled(qat_model.linear1, enabled=True)
+        assert_fake_quant_enabled(qat_model.linear2, enabled=True)
+        assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)
 
         # Fake quant should be applied as normal
         quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model2 = quantizer2.prepare(m3)
-        qat_model2.linear1.weight = qat_model.linear1.weight
-        qat_model2.linear2.weight = qat_model.linear2.weight
-        qat_model2.sub.linear.weight = qat_model.sub.linear.weight
+        qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
+        qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
+        qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -306,9 +335,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        nn_model.linear1.weight = qat_model.linear1.weight
-        nn_model.linear2.weight = qat_model.linear2.weight
-        nn_model.sub.linear.weight = qat_model.sub.linear.weight
+        self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
+        self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
+        self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
 
         # Simulate training for both models
         optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -330,9 +359,55 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         optimizer2.step()
 
         # After 1 training step, weights should match exactly
-        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
-        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
-        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
+        self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
+        self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
+        self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
+
+    def _test_qat_quantized_gradients(self, quantizer):
+        """
+        Test that QAT produces gradients in the backward pass.
+        """
+        num_steps = 10
+        torch.manual_seed(self.SEED)
+        m = M()
+        model = quantizer.prepare(m)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        loss_fn = torch.nn.CrossEntropyLoss()
+
+        # Simulate training
+        current_step = 0
+        last_linear1_grad = None
+        last_linear2_grad = None
+        last_sub_linear_grad = None
+        while current_step < num_steps:
+            example_inputs = model.example_inputs()
+            target = torch.randn(1, 512).float()
+            output = model(*example_inputs)
+            loss = loss_fn(output, target)
+            loss.backward()
+            # assert each linear grad is updated
+            new_linear1_grad = model.linear1.weight.grad
+            new_linear2_grad = model.linear2.weight.grad
+            new_sub_linear_grad = model.sub.linear.weight.grad
+            self.assertIsNotNone(new_linear1_grad)
+            self.assertIsNotNone(new_linear2_grad)
+            self.assertIsNotNone(new_sub_linear_grad)
+            if current_step > 0:
+                self.assertFalse(torch.equal(last_linear1_grad, new_linear1_grad))
+                self.assertFalse(torch.equal(last_linear2_grad, new_linear2_grad))
+                self.assertFalse(torch.equal(last_sub_linear_grad, new_sub_linear_grad))
+            last_linear1_grad = new_linear1_grad
+            last_linear2_grad = new_linear2_grad
+            last_sub_linear_grad = new_sub_linear_grad
+            optimizer.zero_grad()
+            optimizer.step()
+            current_step += 1
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_qat_8da4w_quantizer_gradients(self):
+        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
+        self._test_qat_quantized_gradients(quantizer)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_generic_fake_quantize(self):
@@ -353,7 +428,7 @@ def test_qat_generic_fake_quantize(self):
         block_size = (1, ao_input.shape[-1])
         ao_s = copy.deepcopy(py_s)
         ao_zp = copy.deepcopy(py_zp)
-        ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax, block_size)
+        ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax)
         ao_out.sum().backward()
 
         torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
@@ -373,10 +448,7 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
@@ -464,11 +536,9 @@ def test_qat_4w_quantizer(self):
         qat_quantizer = Int4WeightOnlyQATQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
-        ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size, inner_k_tiles=inner_k_tiles,
-        )
         qat_model = qat_quantizer.prepare(m)
-        ptq_model = ptq_quantizer.quantize(m2)
+        ptq_model = m2
+        quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))
 
         # Compare model values
         torch.manual_seed(self.SEED)
@@ -483,12 +553,11 @@ def test_qat_4w_quantizer(self):
         converted_out = converted_model(*x)
         torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
 
-        # Compare converted state dict
-        ptq_state_dict = ptq_model.state_dict()
-        converted_state_dict = converted_model.state_dict()
-        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
-        for k in ptq_state_dict.keys():
-            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_qat_4w_quantizer_gradients(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
+        self._test_qat_quantized_gradients(quantizer)
 
 
 if __name__ == "__main__":

torchao/quantization/prototype/qat/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,8 @@
     disable_8da4w_fake_quant,
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
+    int4_weight_only_fake_quantize,
+    int8_dynamic_activation_int4_weight_fake_quantize,
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
     Int8DynActInt4WeightQATLinear,
@@ -13,6 +15,8 @@
     "disable_8da4w_fake_quant",
     "enable_4w_fake_quant",
    "enable_8da4w_fake_quant",
+    "int4_weight_only_fake_quantize",
+    "int8_dynamic_activation_int4_weight_fake_quantize",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     "Int8DynActInt4WeightQATLinear",
