     _check_linear_int4_k,
     _replace_linear_int4,
     _replace_linear_8da4w,
+    _replace_embedding_4w,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
     Int8DynActInt4WeightLinear,
@@ -28,6 +29,7 @@
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _get_qmin_qmax
 )
 
 
@@ -47,6 +49,14 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer):
     instead if possible.
     """
 
+    def __init__(self,
+                 quantize_embedding: bool = False,
+                 embedding_groupsize: int = 32,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.quantize_embedding = quantize_embedding
+        self.embedding_groupsize = embedding_groupsize
+
     def prepare(
         self,
         model: torch.nn.Module,
@@ -62,6 +72,14 @@ def prepare(
             Int8DynActInt4WeightQATLinear,
             copy_weights=True,
         )
+        if self.quantize_embedding:
+            _replace_embedding_4w(
+                model,
+                self.embedding_groupsize,
+                Int4WeightQATEmbedding,
+                self.padding_allowed,
+                copy_weights=True
+            )
         return model
 
     def convert(
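
As a side note on how the new flag is meant to be used: `quantize_embedding` is only consulted inside `prepare` above, where it triggers the extra `_replace_embedding_4w` swap. A minimal usage sketch follows; the constructor arguments other than `quantize_embedding`/`embedding_groupsize` come from the existing `Int8DynActInt4WeightQATQuantizer` API and are not part of this diff.

```python
# Hypothetical usage of the new embedding QAT flag; based only on the
# prepare() changes shown above, not on a documented API.
quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(
    quantize_embedding=True,   # also swap nn.Embedding -> Int4WeightQATEmbedding
    embedding_groupsize=32,    # group size for the embedding weight fake quant
)
model = quantizer.prepare(model)   # linears and embeddings now fake-quantize in forward()
# ... QAT fine-tuning ...
model = quantizer.convert(model)   # note: convert() in this diff only lowers the linears
```
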
@@ -92,7 +110,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):
 
             # Load weights and qparams into quantized linear
             n_bit = 4
-            (qmin, qmax) = child._get_qmin_qmax(n_bit)
+            (qmin, qmax) = _get_qmin_qmax(n_bit)
             (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
             from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
             q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
@@ -150,13 +168,14 @@ def enable_fake_quant(self, enabled: bool = True):
     def disable_fake_quant(self):
         self.enable_fake_quant(False)
 
+    # pyre-ignore[14]: inconsistent override
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # activations: int8 dynamic asymmetric quant
         if self._fake_quant_enabled:
             (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
                 x, self.scales_precision, self.zero_points_precision,
             )
-            (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+            (act_qmin, act_qmax) = _get_qmin_qmax(8)
             x_fq = _fake_quantize_per_token(
                 x, act_scales, act_zp, act_qmin, act_qmax,
             )
@@ -170,7 +189,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
             weight_zp = weight_zp.to(self.zero_points_precision)
-            (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(4)
             w_fq = _fake_quantize_per_channel_group(
                 self.weight,
                 weight_scales,
@@ -183,11 +202,58 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             w_fq = self.weight
         return F.linear(x_fq, w_fq)
 
-    # TODO: move this to common util
-    def _get_qmin_qmax(self, n_bit: int):
-        qmin = -(2 ** (n_bit - 1))
-        qmax = 2 ** (n_bit - 1) - 1
-        return (qmin, qmax)
+
+class Int4WeightQATEmbedding(torch.nn.Embedding):
+    """
+    This module implements an embedding layer with int4 fake quantized weights.
+
+    args:
+        groupsize: the number of elements in each quantized group for weights
+        scales_precision: precision of per group scales and zero points
+    """
+
+    def __init__(self,
+                 groupsize: int = 32,
+                 scales_precision: torch.dtype = torch.float32,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+        self.bit_width = 4
+        self.groupsize = groupsize
+        self.scales_precision = scales_precision
+        self.zero_points_precision = torch.int32
+        self._fake_quant_enabled = True
+
+    def forward(self, x):
+        weight = self.weight
+
+        if self._fake_quant_enabled:
+            (weight_scales, weight_zp) = get_group_qparams_symmetric(
+                self.weight, self.bit_width, self.groupsize, self.scales_precision,
+            )
+            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+            weight_zp = weight_zp.to(self.zero_points_precision)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
+            w_fq = _fake_quantize_per_channel_group(
+                self.weight,
+                weight_scales,
+                weight_zp,
+                weight_qmin,
+                weight_qmax,
+                self.groupsize,
+            )
+        else:
+            w_fq = self.weight
+
+        return torch.nn.functional.embedding(
+            x, w_fq, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+    def enable_fake_quant(self, enabled: bool = True):
+        self._fake_quant_enabled = enabled
+
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)
 
 
 def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
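
For readers tracing the `_get_qmin_qmax` change: the per-instance helper removed above is replaced by the module-level `_get_qmin_qmax` added to the imports at the top of this diff. Judging from the deleted method body, the shared helper should be equivalent to the following sketch; the real implementation lives in torchao's QAT utilities, and this is only an illustration of the expected values.

```python
# Inferred behavior of the shared helper, mirroring the removed
# Int8DynActInt4WeightQATLinear._get_qmin_qmax method.
def _get_qmin_qmax(n_bit: int):
    qmin = -(2 ** (n_bit - 1))
    qmax = 2 ** (n_bit - 1) - 1
    return (qmin, qmax)

assert _get_qmin_qmax(4) == (-8, 7)      # int4 weights
assert _get_qmin_qmax(8) == (-128, 127)  # int8 dynamic activations
```
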