
Commit 5704b30

Add an option to not quantize the embedding layer when doing quantization.
This helps achieve better quality for small models (e.g. Gemma 2B).
1 parent 3afdc43 commit 5704b30
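
The new option is an absl boolean flag, so a flag-driven entry point can turn it off from the command line. A hedged example follows; the script name is a placeholder, only the flag name comes from this commit:

    python run_your_entrypoint.py --internal_quantize_embedding_layer=false

absl also accepts the boolean shorthand --nointernal_quantize_embedding_layer.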

File tree

2 files changed: +10, -2 lines

jetstream_pt/quantize_model.py

Lines changed: 9 additions & 1 deletion
@@ -1,4 +1,5 @@
 import torch
+from absl import flags
 from .layers import (
     create_quantized_from_nn_linear,
     create_quantized_from_nn_embedding,
@@ -7,6 +8,13 @@
 )
 
 
+_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
+    "internal_quantize_embedding_layer",
+    True,
+    "Whether to quantize embedding layer or not. Defaults to true",
+)
+
+
 def quantize_model(float_model, config):
   """Apply quantization to linear layers."""
 
@@ -17,7 +25,7 @@ def quantize_nn_mod(float_model):
         new_mod = mod.get_quantized_version()
       elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
-      elif isinstance(mod, torch.nn.Embedding):
+      elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
         new_mod = create_quantized_from_nn_embedding(mod, config)
 
       if new_mod:
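
A minimal, self-contained sketch of the gating logic this hunk adds: Linear modules are always swapped, while Embedding modules are swapped only when the boolean flag is on. The flag name, helper, and toy model below are illustrative stand-ins, not the real jetstream_pt constructors.

    import torch
    from absl import flags

    # Demo flag mirroring internal_quantize_embedding_layer (renamed to avoid clashes).
    _QUANTIZE_EMBEDDING_DEMO = flags.DEFINE_bool(
        "quantize_embedding_demo", True, "Whether to quantize embedding layers in the demo."
    )


    def report_quantization_targets(model):
      """Walks the module tree and reports which layers would be swapped."""
      for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
          print(f"{name}: Linear -> would be replaced by a quantized version")
        elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING_DEMO.value:
          print(f"{name}: Embedding -> would be replaced by a quantized version")


    if __name__ == "__main__":
      # Parsing "--noquantize_embedding_demo" leaves the embedding in float,
      # which is the behavior the commit enables for small models.
      flags.FLAGS(["demo", "--noquantize_embedding_demo"])
      toy = torch.nn.Sequential(torch.nn.Embedding(16, 8), torch.nn.Linear(8, 8))
      report_quantization_targets(toy)

Because the flag is read at call time via _QUANTIZE_EMBEDDING.value, the same quantize_model code path serves both configurations without any change to its signature.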

jetstream_pt/third_party/gemma/model.py

Lines changed: 1 addition & 1 deletion
@@ -437,7 +437,7 @@ def forward(
     hidden_states = self.norm(hidden_states)
 
     embedder_weight = self.embedder.weight
-    if self.env.quant_config.enable_weight_quantization:
+    if hasattr(self.embedder, "weight_scaler"):
       embedder_weight = embedder_weight * self.embedder.weight_scaler
     logits = torch.matmul(hidden_states, embedder_weight.t())
     return logits
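
A hedged sketch of why the guard changes: with the new flag, the embedder may stay a plain nn.Embedding even when weight quantization is enabled overall, so the presence of weight_scaler on the module, rather than the global quant config, decides whether to rescale the tied weights. FakeQuantizedEmbedder below is an illustrative stand-in, not the real quantized embedding class.

    import torch


    class FakeQuantizedEmbedder(torch.nn.Module):
      """Stand-in for a quantized embedding that stores per-row weight scalers."""

      def __init__(self, num_embeddings, dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randint(-8, 8, (num_embeddings, dim)).float())
        self.weight_scaler = torch.nn.Parameter(torch.rand(num_embeddings, 1))


    def tied_logits(embedder, hidden_states):
      """Computes output logits from the (possibly quantized) tied embedding."""
      embedder_weight = embedder.weight
      if hasattr(embedder, "weight_scaler"):  # True only for the quantized embedder.
        embedder_weight = embedder_weight * embedder.weight_scaler
      return torch.matmul(hidden_states, embedder_weight.t())


    hidden = torch.randn(2, 8)
    print(tied_logits(torch.nn.Embedding(16, 8), hidden).shape)     # torch.Size([2, 16])
    print(tied_logits(FakeQuantizedEmbedder(16, 8), hidden).shape)  # torch.Size([2, 16])

Checking the module rather than the config keeps the logits path correct whichever way the new flag is set.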
