@@ -13,6 +13,7 @@
     _check_linear_int4_k,
     _replace_linear_int4,
     _replace_linear_8da4w,
+    _replace_embedding_4w,
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor,
    Int8DynActInt4WeightLinear,
@@ -28,6 +29,7 @@
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _get_qmin_qmax
 )


@@ -47,6 +49,14 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer):
     instead if possible.
     """

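+    # When `quantize_embedding` is True, `prepare` additionally swaps
+    # nn.Embedding children for `Int4WeightQATEmbedding` (defined below).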
+    def __init__(self,
+                 quantize_embedding: bool = False,
+                 embedding_groupsize: int = 32,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.quantize_embedding = quantize_embedding
+        self.embedding_groupsize = embedding_groupsize
+
     def prepare(
         self,
         model: torch.nn.Module,
@@ -62,6 +72,14 @@ def prepare(
             Int8DynActInt4WeightQATLinear,
             copy_weights=True,
         )
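+        # optionally fake-quantize embedding weights to int4 as well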
+        if self.quantize_embedding:
+            _replace_embedding_4w(
+                model,
+                self.embedding_groupsize,
+                Int4WeightQATEmbedding,
+                self.padding_allowed,
+                copy_weights=True
+            )
         return model

     def convert(
@@ -92,7 +110,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):

             # Load weights and qparams into quantized linear
             n_bit = 4
-            (qmin, qmax) = child._get_qmin_qmax(n_bit)
+            (qmin, qmax) = _get_qmin_qmax(n_bit)
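+            # for n_bit=4: (qmin, qmax) == (-8, 7)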
             (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
             from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
             q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
@@ -150,13 +168,14 @@ def enable_fake_quant(self, enabled: bool = True):
     def disable_fake_quant(self):
         self.enable_fake_quant(False)

+    # pyre-ignore[14]: inconsistent override
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # activations: int8 dynamic asymmetric quant
         if self._fake_quant_enabled:
             (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
                 x, self.scales_precision, self.zero_points_precision,
             )
-            (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+            (act_qmin, act_qmax) = _get_qmin_qmax(8)
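+            # for 8-bit activations: (act_qmin, act_qmax) == (-128, 127)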
             x_fq = _fake_quantize_per_token(
                 x, act_scales, act_zp, act_qmin, act_qmax,
             )
@@ -170,7 +189,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
             weight_zp = weight_zp.to(self.zero_points_precision)
-            (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(4)
             w_fq = _fake_quantize_per_channel_group(
                 self.weight,
                 weight_scales,
@@ -183,11 +202,58 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             w_fq = self.weight
         return F.linear(x_fq, w_fq)

-    # TODO: move this to common util
-    def _get_qmin_qmax(self, n_bit: int):
-        qmin = -(2 ** (n_bit - 1))
-        qmax = 2 ** (n_bit - 1) - 1
-        return (qmin, qmax)
+
+class Int4WeightQATEmbedding(torch.nn.Embedding):
+    """
+    This module implements an embedding layer with int4 fake quantized
+    grouped per channel weights.
+
+    Args:
+        groupsize: the number of elements in each quantized group for the weights
+        scales_precision: precision of per group scales and zero points
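+
+    Example (an illustrative sketch; the sizes below are arbitrary)::
+
+        embedding = Int4WeightQATEmbedding(
+            groupsize=32, num_embeddings=1000, embedding_dim=64,
+        )
+        tokens = torch.randint(0, 1000, (2, 8))
+        out = embedding(tokens)  # (2, 8, 64), computed with fake quantized weights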
213+ """
214+
+    def __init__(self,
+                 groupsize: int = 32,
+                 scales_precision: torch.dtype = torch.float32,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+        self.bit_width = 4
+        self.groupsize = groupsize
+        self.scales_precision = scales_precision
+        self.zero_points_precision = torch.int32
+        self._fake_quant_enabled = True
+
+    def forward(self, x):
+        # weights: int4 grouped per channel fake quant
+        if self._fake_quant_enabled:
+            (weight_scales, weight_zp) = get_group_qparams_symmetric(
+                self.weight, self.bit_width, self.groupsize, self.scales_precision,
+            )
+            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+            weight_zp = weight_zp.to(self.zero_points_precision)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
+            w_fq = _fake_quantize_per_channel_group(
+                self.weight,
+                weight_scales,
+                weight_zp,
+                weight_qmin,
+                weight_qmax,
+                self.groupsize,
+            )
+        else:
+            w_fq = self.weight
+
+        return torch.nn.functional.embedding(
+            x, w_fq, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+    def enable_fake_quant(self, enabled: bool = True):
+        self._fake_quant_enabled = enabled
+
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)
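+
+# Illustrative end-to-end usage (a sketch, not part of this change; assumes
+# `model` contains nn.Linear and nn.Embedding children):
+#
+#     quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(
+#         quantize_embedding=True,
+#         embedding_groupsize=32,
+#     )
+#     model = quantizer.prepare(model)   # swap in fake quantized QAT modules
+#     # ... fine-tune with fake quant enabled ...
+#     model = quantizer.convert(model)   # lower linears to Int8DynActInt4WeightLinear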


 def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):