
Commit 1cc1b42

Add int4 weight-only embedding QAT
Based on changes in D62664322 by Tijmen Blankevoort.

TODO:
- add convert path
- add tests
1 parent fbe97a0 commit 1cc1b42

3 files changed: +176 additions, -23 deletions


torchao/quantization/prototype/qat/_module_swap_api.py

Lines changed: 13 additions & 23 deletions
@@ -28,23 +28,23 @@
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _get_qmin_qmax,
 )


-# TODO: deprecate this flow in favor of the tensor subclass flow under qat/api.py
-# This is currently needed for DDP and FSDP1, which are not compatible with the
-# subclass flow.
+# TODO: make module swap the main flow again, and remove the quantize_ flow
+# TODO: rename this file to linear.py
+
+# =========================================================
+# | Linear int8 dynamic activations + int4 weight QAT |
+# =========================================================


 class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have int8
     dynamic per token fake quantized activations and int4 fake quantized
     grouped per channel weights.
-
-    Note: This quantizer is implemented using module swaps and may be
-    deprecated in the future. Please use `Int8DynActInt4WeightQATQuantizer`
-    instead if possible.
     """

     def prepare(

@@ -92,7 +92,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):

         # Load weights and qparams into quantized linear
         n_bit = 4
-        (qmin, qmax) = child._get_qmin_qmax(n_bit)
+        (qmin, qmax) = _get_qmin_qmax(n_bit)
         (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
         from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
         q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(

@@ -156,7 +156,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
                 x, self.scales_precision, self.zero_points_precision,
             )
-            (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+            (act_qmin, act_qmax) = _get_qmin_qmax(8)
             x_fq = _fake_quantize_per_token(
                 x, act_scales, act_zp, act_qmin, act_qmax,
             )

@@ -170,7 +170,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
             weight_zp = weight_zp.to(self.zero_points_precision)
-            (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(4)
             w_fq = _fake_quantize_per_channel_group(
                 self.weight,
                 weight_scales,

@@ -183,12 +183,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             w_fq = self.weight
         return F.linear(x_fq, w_fq)

-    # TODO: move this to common util
-    def _get_qmin_qmax(self, n_bit: int):
-        qmin = -(2 ** (n_bit - 1))
-        qmax = 2 ** (n_bit - 1) - 1
-        return (qmin, qmax)
-

 def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
     """

@@ -206,19 +200,15 @@ def disable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
     mod.disable_fake_quant()


-# ==================
-# | int4wo QAT |
-# ==================
+# ===================================
+# | Linear int4 weight-only QAT |
+# ===================================


 class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer):
     """
     Quantizer for performing QAT on a model, where linear layers have
     int4 fake quantized grouped per channel weights.
-
-    Note: This quantizer is implemented using module swaps and may be
-    deprecated in the future. Please use `Int4WeightOnlyQATQuantizer`
-    instead if possible.
     """

     def prepare(
Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+
+from torchao.quantization.unified import TwoStepQuantizer
+from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+)
+from .utils import (
+    _fake_quantize_per_channel_group,
+    _get_qmin_qmax,
+)
+
+
+# ======================================
+# | Embedding int4 weight-only QAT |
+# ======================================
+
+class Int4WeightOnlyEmbeddingQATQuantizer(TwoStepQuantizer):
+    """
+    Quantizer for performing QAT on a model, where embedding layers have
+    int4 fake quantized grouped per channel weights.
+    """
+
+    def __init__(
+        self,
+        group_size: int = 256,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
+    ) -> None:
+        super().__init__()
+        self.group_size: int = group_size
+        self.scale_precision: torch.dtype = scale_precision
+        self.zero_point_precision: torch.dtype = zero_point_precision
+
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        *args: Any,
+        **kwargs: Any
+    ) -> torch.nn.Module:
+        """
+        Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`.
+        """
+        def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+            return isinstance(child, torch.nn.Embedding)
+
+        def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+            new_embedding = Int4WeightOnlyQATEmbedding(
+                group_size=self.group_size,
+
+                # other nn.Embedding args
+                num_embeddings=child.num_embeddings,
+                embedding_dim=child.embedding_dim,
+                padding_idx=child.padding_idx,
+                max_norm=child.max_norm,
+                norm_type=child.norm_type,
+                scale_grad_by_freq=child.scale_grad_by_freq,
+                sparse=child.sparse,
+                device=child.weight.device,
+            )
+            # In distributed training, the model may be instantiated
+            # on the meta device, in which case there is no need to
+            # copy the weights, and doing so will result in an error
+            if child.weight.device != torch.device("meta"):
+                new_embedding.weight = child.weight
+            return new_embedding
+
+        _replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn)
+        return model
+
+    def convert(
+        self,
+        model: torch.nn.Module,
+        *args: Any,
+        **kwargs: Any
+    ) -> torch.nn.Module:
+        """
+        Swap `Int4WeightOnlyQATEmbedding` with `Int4WeightOnlyEmbedding`.
+        """
+        # TODO: implement this
+        print("Warning: int4 weight-only embedding convert flow not implemented yet")
+        return model
+
+
+class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
+    """
+    This module implements an embedding layer with int4 fake quantized
+    grouped per channel weights.
+
+    Args:
+        group_size: the number of elements in each quantized group for weights
+        scale_precision: precision of per group scales
+        zero_point_precision: precision of per group zero points
+    """
+
+    def __init__(
+        self,
+        group_size: int = 32,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.bit_width = 4
+        self.group_size = group_size
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+        self._fake_quant_enabled = True
+
+    def forward(self, x):
+        weight = self.weight
+
+        if self._fake_quant_enabled:
+            (weight_scales, weight_zp) = get_group_qparams_symmetric(
+                self.weight, self.bit_width, self.group_size, self.scale_precision,
+            )
+            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+            weight_zp = weight_zp.to(self.zero_point_precision)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
+            w_fq = _fake_quantize_per_channel_group(
+                self.weight,
+                weight_scales,
+                weight_zp,
+                weight_qmin,
+                weight_qmax,
+                self.group_size,
+            )
+        else:
+            w_fq = self.weight
+
+        return F.embedding(
+            x, w_fq, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        )
+
+    def enable_fake_quant(self, enabled: bool = True):
+        self._fake_quant_enabled = enabled
+
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)
+
+
+class Int4WeightOnlyEmbedding(torch.nn.Embedding):
+    """
+    This module implements an embedding layer with int4 quantized
+    grouped per channel weights.
+    """
+    pass
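
For orientation, here is a minimal usage sketch of the new quantizer on a toy model. The toy model, its dimensions, and the import path are assumptions for illustration only (the new file's path is not shown in this excerpt), and the convert step is still a no-op per the TODO above.

import torch
import torch.nn as nn

# Assumed import path for the new file added in this commit; adjust to
# wherever the module actually lives in the repo.
from torchao.quantization.prototype.qat.embedding import (
    Int4WeightOnlyEmbeddingQATQuantizer,
)

# Hypothetical toy model used only to demonstrate the prepare flow.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(1024, 64)
        self.out = nn.Linear(64, 1024)

    def forward(self, idx):
        return self.out(self.tok_emb(idx))

model = ToyModel()
quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size=32)

# prepare() swaps every nn.Embedding for Int4WeightOnlyQATEmbedding, which
# fake-quantizes the embedding weights (int4, grouped per channel) on each
# forward pass while keeping them in floating point for training.
model = quantizer.prepare(model)

# ... fine-tune with fake quantization enabled ...
logits = model(torch.randint(0, 1024, (2, 16)))

# convert() is still a TODO in this commit: it prints a warning and returns
# the model unchanged.
model = quantizer.convert(model)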

torchao/quantization/prototype/qat/utils.py

Lines changed: 5 additions & 0 deletions
@@ -259,3 +259,8 @@ def insert_subclass(lin):
         return lin

     return insert_subclass
+
+def _get_qmin_qmax(n_bit: int):
+    qmin = -(2 ** (n_bit - 1))
+    qmax = 2 ** (n_bit - 1) - 1
+    return (qmin, qmax)
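
The shared helper added here just computes the signed two's-complement range for a given bit width. A quick sanity-check sketch (not part of the commit):

from torchao.quantization.prototype.qat.utils import _get_qmin_qmax

# 4-bit weights and 8-bit activations, the two widths used in the QAT code above.
assert _get_qmin_qmax(4) == (-8, 7)
assert _get_qmin_qmax(8) == (-128, 127)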
