
Commit 664c124

feat: add quantize exclude layer flag (#194)
* feat: add quantize exclude layer flag

  Instead of internal_quantize_embedding_layer, add a flag that allows specifying a list of layer names to exclude from quantization. This is a flexible way to avoid quantizing a given list of layers: the Embedding layer, but also a few more if required.

* review: run ink to fix format
1 parent e8f5f00 commit 664c124
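For context, a minimal end-to-end sketch of the feature under stated assumptions: the layer names are illustrative (real names come from the model's named_modules()), and the other QuantizationConfig fields are assumed to keep their dataclass defaults.

    # Sketch only: "tok_embeddings" and "output" are hypothetical module
    # names, not names guaranteed to exist in any given model.
    import torch
    from jetstream_pt.environment import QuantizationConfig
    from jetstream_pt.quantize_model import quantize_model

    model = ...  # a float torch.nn.Module loaded elsewhere
    config = QuantizationConfig(exclude_layers=["tok_embeddings", "output"])
    quantized = quantize_model(model, config)  # listed layers stay float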

File tree

jetstream_pt/cli.py
jetstream_pt/config.py
jetstream_pt/environment.py
jetstream_pt/quantize_model.py

4 files changed: +15 -14 lines

jetstream_pt/cli.py

Lines changed: 2 additions & 3 deletions
@@ -34,11 +34,10 @@
     "if set, then save the result to the given file name",
 )
 flags.DEFINE_bool(
-    "internal_use_local_tokenizer",
-    0,
-    "Use local tokenizer if set to True"
+    "internal_use_local_tokenizer", 0, "Use local tokenizer if set to True"
 )
 
+
 def shard_weights(env, weights, weight_shardings):
   """Shard weights according to weight_shardings"""
   sharded = {}

jetstream_pt/config.py

Lines changed: 6 additions & 0 deletions
@@ -49,6 +49,11 @@
 flags.DEFINE_bool(
     "quantize_kv_cache", None, "defaults to the same value as quantize_weights"
 )
+flags.DEFINE_multi_string(
+    "quantize_exclude_layers",
+    None,
+    "List of layer names to exclude from quantization",
+)
 
 _VALID_QUANTIZATION_TYPE = {
     "int8_per_channel",
@@ -178,6 +183,7 @@ def create_quantization_config_from_flags():
   config.is_blockwise_weight = "blockwise" in quantize_type
 
   config.enable_activation_quantization = FLAGS.quantize_activation
+  config.exclude_layers = FLAGS.quantize_exclude_layers
   config.enable_kv_quantization = (
       FLAGS.quantize_kv_cache
       if FLAGS.quantize_kv_cache is not None
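Since this is an absl multi-string flag, it is passed once per layer name. A minimal, self-contained parsing sketch (the flag values are illustrative):

    # Sketch: after parsing, FLAGS.quantize_exclude_layers is a list of the
    # given names, or None when the flag is never passed.
    from absl import flags

    FLAGS = flags.FLAGS
    flags.DEFINE_multi_string(
        "quantize_exclude_layers",
        None,
        "List of layer names to exclude from quantization",
    )
    FLAGS([
        "prog",
        "--quantize_exclude_layers=tok_embeddings",
        "--quantize_exclude_layers=output",
    ])
    print(FLAGS.quantize_exclude_layers)  # ['tok_embeddings', 'output']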

jetstream_pt/environment.py

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import dataclasses
-from typing import Tuple
+from typing import List, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -37,6 +37,7 @@ class QuantizationConfig:
 
   enable_activation_quantization: bool = False
   enable_kv_quantization: bool = False
+  exclude_layers: Union[None, List[str]] = None
 
 
 @dataclasses.dataclass

jetstream_pt/quantize_model.py

Lines changed: 5 additions & 10 deletions
@@ -1,5 +1,5 @@
 import torch
-from absl import flags
+from .environment import QuantizationConfig
 from .layers import (
     create_quantized_from_nn_linear,
     create_quantized_from_nn_embedding,
@@ -8,24 +8,19 @@
 )
 
 
-_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
-    "internal_quantize_embedding_layer",
-    True,
-    "Whether to quantize embedding layer or not. Defaults to true",
-)
-
-
-def quantize_model(float_model, config):
+def quantize_model(float_model, config: QuantizationConfig):
   """Apply quantization to linear layers."""
 
   def quantize_nn_mod(float_model):
     for name, mod in float_model.named_modules():
       new_mod = None
+      if config.exclude_layers and name in config.exclude_layers:
+        continue
       if hasattr(mod, "get_quantized_version"):
         new_mod = mod.get_quantized_version()
       elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
-      elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
+      elif isinstance(mod, torch.nn.Embedding):
         new_mod = create_quantized_from_nn_embedding(mod, config)
 
       if new_mod:
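The exclusion check compares against the dotted names yielded by torch.nn.Module.named_modules(), so nested layers must be excluded by their full path. A toy illustration (the module names here are illustrative, not from any real checkpoint):

    import torch

    class Toy(torch.nn.Module):
        # Shows the names the exclusion check compares against.
        def __init__(self):
            super().__init__()
            self.tok_embeddings = torch.nn.Embedding(32, 8)
            self.block = torch.nn.ModuleDict({"wq": torch.nn.Linear(8, 8)})

    for name, _ in Toy().named_modules():
        print(repr(name))  # '', 'tok_embeddings', 'block', 'block.wq'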
