Commit 02927c9

Authored Oct 23, 2024
Fix: correct quantization name filtering (#196)
The quantization filter based on layer names did not work. The module walk is done with the Module.apply method, which calls the function on each submodule in turn, so names are resolved locally (relative to the submodule being visited) and the "absolute" names listed in config.exclude_layers do not match there. The fix builds a list of the excluded modules from those names before entering the loop, so the filter compares against the correct module references instead of names.
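The name mismatch described above can be reproduced without PyTorch. The following is a minimal sketch, assuming a hypothetical `Node` class that only mimics how `torch.nn.Module.apply` (children first, then self) and `named_modules` (names relative to the receiver) behave. `Node`, `leaks`, and the layer names are illustrative and not part of the repository.

```python
class Node:
  """Hypothetical stand-in for torch.nn.Module (not the real class)."""

  def __init__(self, **children):
    self.children = children

  def named_modules(self, prefix=""):
    # Yield this node, then all descendants, with names relative to self.
    yield prefix, self
    for key, child in self.children.items():
      child_prefix = f"{prefix}.{key}" if prefix else key
      yield from child.named_modules(child_prefix)

  def apply(self, fn):
    # Visit children first, then self, mirroring Module.apply.
    for child in self.children.values():
      child.apply(fn)
    fn(self)
    return self


model = Node(block=Node(attn=Node(), mlp=Node()))
exclude_layers = ["block.attn"]  # "absolute" name, relative to the root
target = model.children["block"].children["attn"]

# The fix: resolve the names once at the root and keep module references.
exclude_mods = [m for n, m in model.named_modules() if n in exclude_layers]


def leaks(should_skip):
  """Return the local names under which the excluded module was still reached."""
  reached = []

  def callback(sub):
    for name, mod in sub.named_modules():
      if should_skip(name, mod):
        continue
      if mod is target:
        reached.append(name)

  model.apply(callback)
  return reached


by_name = leaks(lambda name, mod: name in exclude_layers)  # old filter
by_ref = leaks(lambda name, mod: mod in exclude_mods)      # new filter
print(by_name)  # ['', 'attn']
print(by_ref)   # []
```

In this sketch the name-based filter only skips the excluded module while the root is being visited. During the visits to the module itself (local name `''`) and to its parent (local name `'attn'`) the filter lets it through, whereas the reference-based filter skips it at every level.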
1 parent 664c124 commit 02927c9

1 file changed: +8 -1 lines changed

jetstream_pt/quantize_model.py

@@ -10,11 +10,18 @@
 
 def quantize_model(float_model, config: QuantizationConfig):
   """Apply quantization to linear layers."""
+  exclude_mods = None
+  if config.exclude_layers:
+    exclude_mods = [
+        module
+        for name, module in float_model.named_modules()
+        if name in config.exclude_layers
+    ]
 
   def quantize_nn_mod(float_model):
     for name, mod in float_model.named_modules():
       new_mod = None
-      if config.exclude_layers and name in config.exclude_layers:
+      if config.exclude_layers and mod in exclude_mods:
         continue
       if hasattr(mod, "get_quantized_version"):
         new_mod = mod.get_quantized_version()
