2 files changed: 10 additions, 2 deletions
File 1 of 2:

 import torch
+from absl import flags

 from .layers import (
     create_quantized_from_nn_linear,
     create_quantized_from_nn_embedding,
 )


+_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
+    "internal_quantize_embedding_layer",
+    True,
+    "Whether to quantize embedding layer or not. Defaults to true",
+)
+
+
 def quantize_model(float_model, config):
   """Apply quantization to linear layers."""
@@ -17,7 +25,7 @@ def quantize_nn_mod(float_model):
         new_mod = mod.get_quantized_version()
       elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
-      elif isinstance(mod, torch.nn.Embedding):
+      elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
         new_mod = create_quantized_from_nn_embedding(mod, config)

       if new_mod:
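The change above gates embedding quantization behind an absl boolean flag that defaults to True, so existing behavior is preserved unless the flag is turned off (with absl's standard negation, e.g. --nointernal_quantize_embedding_layer). Below is a minimal, self-contained sketch of that flag-gating pattern; the quantize_module helper, the main entry point, and the layer sizes are illustrative stand-ins, not part of this change.

# Sketch of the flag-gating pattern from the diff (helper names are illustrative).
from absl import app, flags
import torch

_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
    "internal_quantize_embedding_layer",
    True,
    "Whether to quantize embedding layer or not. Defaults to true",
)


def quantize_module(mod):
  # Hypothetical stand-in for create_quantized_from_nn_linear /
  # create_quantized_from_nn_embedding in the diff.
  if isinstance(mod, torch.nn.Linear):
    return f"quantized Linear({mod.in_features}, {mod.out_features})"
  if isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
    return f"quantized Embedding({mod.num_embeddings}, {mod.embedding_dim})"
  return None  # module left unquantized


def main(argv):
  del argv  # unused
  print(quantize_module(torch.nn.Embedding(32000, 4096)))
  print(quantize_module(torch.nn.Linear(4096, 4096)))


if __name__ == "__main__":
  app.run(main)  # flags are parsed here, before main runs and .value is read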
File 2 of 2:

@@ -437,7 +437,7 @@ def forward(
     hidden_states = self.norm(hidden_states)

     embedder_weight = self.embedder.weight
-    if self.env.quant_config.enable_weight_quantization:
+    if hasattr(self.embedder, "weight_scaler"):
       embedder_weight = embedder_weight * self.embedder.weight_scaler
     logits = torch.matmul(hidden_states, embedder_weight.t())
     return logits
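This second change makes the forward pass check the embedder itself rather than the global quant config: with the new flag, weight quantization can be enabled overall while the embedding layer stays unquantized, in which case there is no weight_scaler to apply. A short sketch of the duck-typed check follows; the QuantizedEmbedding class and the per-row scaler shape are assumptions for illustration, not the repository's actual quantized layer.

# Sketch of the hasattr-based de-quantization check (QuantizedEmbedding is illustrative).
import torch


class QuantizedEmbedding(torch.nn.Embedding):
  def __init__(self, num_embeddings, embedding_dim):
    super().__init__(num_embeddings, embedding_dim)
    # Per-row scale produced by weight quantization (assumed shape).
    self.register_buffer("weight_scaler", torch.ones(num_embeddings, 1))


def lm_head_logits(hidden_states, embedder):
  embedder_weight = embedder.weight
  if hasattr(embedder, "weight_scaler"):
    # Only quantized embeddings carry a scaler; plain nn.Embedding is used as-is.
    embedder_weight = embedder_weight * embedder.weight_scaler
  return torch.matmul(hidden_states, embedder_weight.t())


# Works with either a plain or a quantized embedding:
h = torch.randn(2, 8)
print(lm_head_logits(h, torch.nn.Embedding(16, 8)).shape)  # torch.Size([2, 16])
print(lm_head_logits(h, QuantizedEmbedding(16, 8)).shape)  # torch.Size([2, 16])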