4 files changed, +15 -14 lines changed
File 1 of 4 (filename not captured):

@@ -34,11 +34,10 @@
     "if set, then save the result to the given file name",
 )
 flags.DEFINE_bool(
-    "internal_use_local_tokenizer",
-    0,
-    "Use local tokenizer if set to True"
+    "internal_use_local_tokenizer", 0, "Use local tokenizer if set to True"
 )
 
+
 def shard_weights(env, weights, weight_shardings):
   """Shard weights according to weight_shardings"""
   sharded = {}
File 2 of 4 (filename not captured):

@@ -49,6 +49,11 @@
 flags.DEFINE_bool(
     "quantize_kv_cache", None, "defaults to the same value as quantize_weights"
 )
+flags.DEFINE_multi_string(
+    "quantize_exclude_layers",
+    None,
+    "List of layer names to exclude from quantization",
+)
 
 _VALID_QUANTIZATION_TYPE = {
     "int8_per_channel",
@@ -178,6 +183,7 @@ def create_quantization_config_from_flags():
   config.is_blockwise_weight = "blockwise" in quantize_type
 
   config.enable_activation_quantization = FLAGS.quantize_activation
+  config.exclude_layers = FLAGS.quantize_exclude_layers
   config.enable_kv_quantization = (
       FLAGS.quantize_kv_cache
       if FLAGS.quantize_kv_cache is not None
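Because `quantize_exclude_layers` is defined with `DEFINE_multi_string`, it is passed once per layer name on the command line. A hypothetical invocation (the script name and layer names are illustrative, not part of this change):

    python convert_checkpoints.py \
        --quantize_weights=True \
        --quantize_exclude_layers=lm_head \
        --quantize_exclude_layers=embed_tokens

Left unset, the flag defaults to None, so existing callers see no change in quantization behavior.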
File 3 of 4 (filename not captured):

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import dataclasses
-from typing import Tuple
+from typing import List, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -37,6 +37,7 @@ class QuantizationConfig:
 
   enable_activation_quantization: bool = False
   enable_kv_quantization: bool = False
+  exclude_layers: Union[None, List[str]] = None
 
 
 @dataclasses.dataclass
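Since `QuantizationConfig` is a dataclass and the new field has a default, it composes with direct keyword construction as well as with the flag path above. A minimal sketch (the import path and layer name are assumptions for illustration, not taken from this diff):

    from jetstream_pt.environment import QuantizationConfig

    config = QuantizationConfig(
        enable_kv_quantization=True,
        exclude_layers=["lm_head"],  # modules named here keep float weights
    )

`Union[None, List[str]]` is equivalent to `Optional[List[str]]`; the None default disables exclusion entirely.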
File 4 of 4 (filename not captured):

@@ -1,5 +1,5 @@
 import torch
-from absl import flags
+from .environment import QuantizationConfig
 from .layers import (
     create_quantized_from_nn_linear,
     create_quantized_from_nn_embedding,
@@ -8,24 +8,19 @@
 )
 
 
-_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
-    "internal_quantize_embedding_layer",
-    True,
-    "Whether to quantize embedding layer or not. Defaults to true",
-)
-
-
-def quantize_model(float_model, config):
+def quantize_model(float_model, config: QuantizationConfig):
   """Apply quantization to linear layers."""
 
   def quantize_nn_mod(float_model):
     for name, mod in float_model.named_modules():
       new_mod = None
+      if config.exclude_layers and name in config.exclude_layers:
+        continue
       if hasattr(mod, "get_quantized_version"):
         new_mod = mod.get_quantized_version()
       elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
-      elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
+      elif isinstance(mod, torch.nn.Embedding):
         new_mod = create_quantized_from_nn_embedding(mod, config)
 
       if new_mod:
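The `name` matched against `config.exclude_layers` is the dotted path yielded by `named_modules()`, so nested layers must be listed by their full path. A small sketch with a toy model (hypothetical, for illustration only):

    import torch

    class Toy(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(100, 16)
        self.decoder = torch.nn.Sequential(torch.nn.Linear(16, 16))

    print([n for n, _ in Toy().named_modules()])
    # ['', 'embed', 'decoder', 'decoder.0']
    # Excluding the nested Linear requires "decoder.0", not "0" or "Linear".

Note also that the `continue` runs before the `hasattr(mod, "get_quantized_version")` check, so exclusion overrides even modules that provide their own quantized version.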