
Commit 68df007

TiRune authored and facebook-github-bot committed
Add Embedding Quantization to QAT module_swap flow (#886)
Summary: Pull Request resolved: #886

Adds the embedding quantizer in the same fashion as the existing module-swap setup.

Differential Revision: D62664322
1 parent 72d2518 commit 68df007
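For context, a minimal usage sketch (not part of this commit) of what the new option enables. It assumes the quantizer is importable from torchao.quantization.prototype.qat._module_swap_api, as the file path below suggests, that the toy model is made up, and that its dimensions are divisible by the default linear group size and by the chosen embedding group size:

import torch.nn as nn
from torchao.quantization.prototype.qat._module_swap_api import (
    Int8DynActInt4WeightQATQuantizerModuleSwap,
)

# Hypothetical toy model; embedding_dim and in_features are chosen so the
# swap filters accept them without padding.
model = nn.Sequential(
    nn.Embedding(num_embeddings=1000, embedding_dim=512),
    nn.Linear(512, 512),
)

quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(
    quantize_embedding=True,   # new flag added by this commit
    embedding_groupsize=32,    # group size for the int4 embedding weights
)
# prepare() swaps nn.Embedding -> Int4WeightQATEmbedding and
# nn.Linear -> Int8DynActInt4WeightQATLinear for QAT fine-tuning.
model = quantizer.prepare(model)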

File tree: 3 files changed (+114, -8 lines)


torchao/quantization/GPTQ.py

Lines changed: 35 additions & 0 deletions
@@ -965,6 +965,41 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.precision,
         )
 
+
+def _replace_embedding_4w(
+    module: torch.nn.Module,
+    groupsize: int,
+    embedding_class: Type[torch.nn.Module],
+    padding_allowed: bool,
+    copy_weights: bool = False,
+):
+    # import the util function here to avoid circular dependency
+    from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, nn.Embedding) and (_check_linear_int4_k(child.embedding_dim, groupsize) or padding_allowed)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_embedding = embedding_class(
+            num_embeddings=child.num_embeddings,
+            embedding_dim=child.embedding_dim,
+            padding_idx=child.padding_idx,
+            max_norm=child.max_norm,
+            norm_type=child.norm_type,
+            scale_grad_by_freq=child.scale_grad_by_freq,
+            sparse=child.sparse,
+            device=child.weight.device,
+            groupsize=groupsize,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if copy_weights and child.weight.device != torch.device("meta"):
+            new_embedding.weight = child.weight
+        return new_embedding
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
 def _replace_linear_8da4w(
     module: torch.nn.Module,
     groupsize: int,
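The helper above can also be driven directly. A hedged sketch, assuming the import paths follow the file paths in this diff and using a made-up single-embedding model whose embedding_dim is divisible by the group size so filter_fn accepts it:

import torch.nn as nn
from torchao.quantization.GPTQ import _replace_embedding_4w
from torchao.quantization.prototype.qat._module_swap_api import Int4WeightQATEmbedding

model = nn.Sequential(nn.Embedding(num_embeddings=1000, embedding_dim=128))

# Swap every matching nn.Embedding for the QAT embedding class, copying the
# original weights since the model is not on the meta device.
_replace_embedding_4w(
    model,
    groupsize=32,
    embedding_class=Int4WeightQATEmbedding,
    padding_allowed=False,
    copy_weights=True,
)
assert isinstance(model[0], Int4WeightQATEmbedding)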

torchao/quantization/prototype/qat/_module_swap_api.py

Lines changed: 74 additions & 8 deletions
@@ -13,6 +13,7 @@
     _check_linear_int4_k,
     _replace_linear_int4,
     _replace_linear_8da4w,
+    _replace_embedding_4w,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
     Int8DynActInt4WeightLinear,
@@ -28,6 +29,7 @@
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _get_qmin_qmax,
 )
 
 
@@ -47,6 +49,14 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantize
     instead if possible.
     """
 
+    def __init__(self,
+                 quantize_embedding: bool = False,
+                 embedding_groupsize: int = 32,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.quantize_embedding = quantize_embedding
+        self.embedding_groupsize = embedding_groupsize
+
     def prepare(
         self,
         model: torch.nn.Module,
@@ -62,6 +72,14 @@ def prepare(
             Int8DynActInt4WeightQATLinear,
             copy_weights=True,
         )
+        if self.quantize_embedding:
+            _replace_embedding_4w(
+                model,
+                self.embedding_groupsize,
+                Int4WeightQATEmbedding,
+                self.padding_allowed,
+                copy_weights=True,
+            )
         return model
 
     def convert(
@@ -92,7 +110,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):
 
             # Load weights and qparams into quantized linear
             n_bit = 4
-            (qmin, qmax) = child._get_qmin_qmax(n_bit)
+            (qmin, qmax) = _get_qmin_qmax(n_bit)
             (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
             from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
             q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
@@ -150,13 +168,14 @@ def enable_fake_quant(self, enabled: bool = True):
     def disable_fake_quant(self):
         self.enable_fake_quant(False)
 
+    # pyre-ignore[14]: inconsistent override
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # activations: int8 dynamic asymmetric quant
         if self._fake_quant_enabled:
             (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
                 x, self.scales_precision, self.zero_points_precision,
             )
-            (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+            (act_qmin, act_qmax) = _get_qmin_qmax(8)
             x_fq = _fake_quantize_per_token(
                 x, act_scales, act_zp, act_qmin, act_qmax,
             )
@@ -170,7 +189,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
             weight_zp = weight_zp.to(self.zero_points_precision)
-            (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(4)
             w_fq = _fake_quantize_per_channel_group(
                 self.weight,
                 weight_scales,
@@ -183,11 +202,58 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             w_fq = self.weight
         return F.linear(x_fq, w_fq)
 
-    # TODO: move this to common util
-    def _get_qmin_qmax(self, n_bit: int):
-        qmin = -(2 ** (n_bit - 1))
-        qmax = 2 ** (n_bit - 1) - 1
-        return (qmin, qmax)
+
+class Int4WeightQATEmbedding(torch.nn.Embedding):
+    """
+    This module implements an embedding layer with int4 fake quantized group-wise weights.
+
+    args:
+        groupsize: the number of elements in each quantized group for the weights
+        scales_precision: precision of the per-group scales and zero points
+    """
+
+    def __init__(self,
+                 groupsize: int = 32,
+                 scales_precision: torch.dtype = torch.float32,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+        self.bit_width = 4
+        self.groupsize = groupsize
+        self.scales_precision = scales_precision
+        self.zero_points_precision = torch.int32
+        self._fake_quant_enabled = True
+
+    def forward(self, x):
+        weight = self.weight
+
+        if self._fake_quant_enabled:
+            (weight_scales, weight_zp) = get_group_qparams_symmetric(
+                self.weight, self.bit_width, self.groupsize, self.scales_precision,
+            )
+            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+            weight_zp = weight_zp.to(self.zero_points_precision)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
+            w_fq = _fake_quantize_per_channel_group(
+                self.weight,
+                weight_scales,
+                weight_zp,
+                weight_qmin,
+                weight_qmax,
+                self.groupsize,
+            )
+        else:
+            w_fq = self.weight
+
+        return torch.nn.functional.embedding(
+            x, w_fq, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+    def enable_fake_quant(self, enabled: bool = True):
+        self._fake_quant_enabled = enabled
+
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)
 
 
 def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
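A short sketch (not from the commit) of how the new embedding behaves, assuming the import path matches the file above; the shapes are made up but divisible by the group size:

import torch
from torchao.quantization.prototype.qat._module_swap_api import Int4WeightQATEmbedding

emb = Int4WeightQATEmbedding(groupsize=32, num_embeddings=100, embedding_dim=64)
tokens = torch.randint(0, 100, (2, 8))

out_fq = emb(tokens)   # lookup through fake-quantized int4 group-wise weights
emb.disable_fake_quant()
out_fp = emb(tokens)   # plain float lookup with the same weights
print((out_fq - out_fp).abs().max())  # fake-quantization error on this batch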

torchao/quantization/prototype/qat/utils.py

Lines changed: 5 additions & 0 deletions
@@ -259,3 +259,8 @@ def insert_subclass(lin):
         return lin
 
     return insert_subclass
+
+def _get_qmin_qmax(n_bit: int):
+    qmin = -(2 ** (n_bit - 1))
+    qmax = 2 ** (n_bit - 1) - 1
+    return (qmin, qmax)
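For reference, the shared helper reduces to the usual signed n-bit range [-(2^(n-1)), 2^(n-1) - 1]. A quick check, with the import path inferred from the file path above:

from torchao.quantization.prototype.qat.utils import _get_qmin_qmax

assert _get_qmin_qmax(4) == (-8, 7)      # int4 weight range
assert _get_qmin_qmax(8) == (-128, 127)  # int8 activation range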
