Commit 8cd31b8

Merge branch 'main' into sixiang
2 parents: 15c9a48 + 5704b30

2 files changed (+10, -2 lines changed)


jetstream_pt/quantize_model.py

Lines changed: 9 additions & 1 deletion
@@ -1,4 +1,5 @@
 import torch
+from absl import flags
 from .layers import (
     create_quantized_from_nn_linear,
     create_quantized_from_nn_embedding,
@@ -7,6 +8,13 @@
 )
 
 
+_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
+    "internal_quantize_embedding_layer",
+    True,
+    "Whether to quantize embedding layer or not. Defaults to true",
+)
+
+
 def quantize_model(float_model, config):
   """Apply quantization to linear layers."""
 
@@ -17,7 +25,7 @@ def quantize_nn_mod(float_model):
         new_mod = mod.get_quantized_version()
       elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
-      elif isinstance(mod, torch.nn.Embedding):
+      elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
         new_mod = create_quantized_from_nn_embedding(mod, config)
 
       if new_mod:
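
The new flag defaults to True, so existing behavior is unchanged; passing --nointernal_quantize_embedding_layer (or --internal_quantize_embedding_layer=false) to whichever absl entry point imports this module keeps the embedding layer in float while linear layers are still quantized. Below is a minimal sketch of flipping the flag programmatically, e.g. from a test; the helper name is hypothetical and not part of this commit.

from absl import flags

from jetstream_pt.quantize_model import quantize_model

FLAGS = flags.FLAGS


def quantize_keep_float_embedding(float_model, config):
  """Hypothetical helper: quantize linear layers, leave nn.Embedding in float."""
  if not FLAGS.is_parsed():
    FLAGS.mark_as_parsed()  # allow flag access outside app.run(), e.g. in tests
  FLAGS.internal_quantize_embedding_layer = False
  return quantize_model(float_model, config)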

jetstream_pt/third_party/gemma/model.py

Lines changed: 1 addition & 1 deletion
@@ -437,7 +437,7 @@ def forward(
     hidden_states = self.norm(hidden_states)
 
     embedder_weight = self.embedder.weight
-    if self.env.quant_config.enable_weight_quantization:
+    if hasattr(self.embedder, "weight_scaler"):
       embedder_weight = embedder_weight * self.embedder.weight_scaler
     logits = torch.matmul(hidden_states, embedder_weight.t())
     return logits
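
This change swaps the global enable_weight_quantization check for a duck-typed one: the scaler is applied only when the embedder object actually carries a weight_scaler, which is the case when the embedding layer itself was quantized. With the new flag turned off, weight quantization can stay enabled globally while the embedder remains a plain nn.Embedding, so the old check would have hit a missing attribute. A small self-contained sketch of the two paths follows; the quantized class below is a stand-in for illustration, not the module produced by create_quantized_from_nn_embedding.

import torch


class FakeQuantizedEmbedding(torch.nn.Module):
  """Stand-in for a quantized embedding: int8 weight plus a per-row scaler."""

  def __init__(self, num_embeddings, dim):
    super().__init__()
    float_w = torch.randn(num_embeddings, dim)
    scaler = float_w.abs().amax(dim=1, keepdim=True) / 127.0
    self.weight = torch.round(float_w / scaler).to(torch.int8)
    self.weight_scaler = scaler


def lm_head_logits(hidden_states, embedder):
  # Mirrors the patched forward(): rescale only when a scaler is present.
  embedder_weight = embedder.weight
  if hasattr(embedder, "weight_scaler"):
    embedder_weight = embedder_weight * embedder.weight_scaler
  return torch.matmul(hidden_states, embedder_weight.t())


h = torch.randn(2, 16)
print(lm_head_logits(h, torch.nn.Embedding(32, 16)).shape)      # float path
print(lm_head_logits(h, FakeQuantizedEmbedding(32, 16)).shape)  # quantized path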

0 commit comments
