Commit 5a77b59

Keep quantization enabled during calibration (#1299)
## Purpose ##

* Revert the behavior regression introduced by #1114
* When calibrating a model with the `QuantizationModifier`, quantization should remain enabled during calibration

## Changes ##

* Remove "disabling quantization" from the calibration forward pass
* Add "disabling quantization" to the sequential pipelines so that quantization continues to be disabled during calibration for GPTQ and SGPT
* When [calibration pipelines become shared between modifiers](#1279), the decision of whether to disable quantization during calibration will have to move into the calibration pipelines themselves. Some work remains to demonstrate that GPTQ and SGPT do not suffer an accuracy regression from enabling activation quantization during calibration (in theory, the change should increase accuracy)

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 30d45c5 commit 5a77b59
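For orientation, the behavioral difference can be sketched as follows. This is a simplified illustration of the calibration loop, not the actual modifier or pipeline code; the two function names are hypothetical, and `model` and `dataloader` stand in for whatever the caller provides.

```python
from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context


def calibrate_with_quantization_enabled(model, dataloader):
    # QuantizationModifier-style calibration after this change: quantization
    # stays enabled, so observers calibrate against the same quantized
    # activations the compressed model will actually produce.
    with calibration_forward_context(model):
        for batch in dataloader:
            model(**batch)


def calibrate_with_quantization_disabled(model, dataloader):
    # GPTQ / SGPT sequential pipelines: quantization is still disabled during
    # calibration by composing DisableQuantization explicitly, preserving the
    # previous behavior for those modifiers.
    with calibration_forward_context(model), DisableQuantization(model):
        for batch in dataloader:
            model(**batch)
```

Previously, `calibration_forward_context` disabled quantization unconditionally, which meant the `QuantizationModifier` calibrated its observers against unquantized activations.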

File tree: 4 files changed (+8 -9 lines)


Diff for: src/llmcompressor/pipelines/layer_sequential/pipeline.py (+2 -2)

@@ -12,7 +12,7 @@
     maybe_inject_pos_embeddings,
     to_next_layer_kwargs,
 )
-from llmcompressor.utils.helpers import calibration_forward_context
+from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

 if TYPE_CHECKING:
     from llmcompressor.modifiers import Modifier
@@ -51,7 +51,7 @@ def run_pipeline(
     # find layers
     layers = match_modules(model, sequential_targets)

-    with calibration_forward_context(model):
+    with calibration_forward_context(model), DisableQuantization(model):
         # prepare intermediates cache
         intermediates: IntermediatesCache = capture_first_layer_intermediates(
             model, layers[0], dataloader

Diff for: src/llmcompressor/pipelines/sequential/pipeline.py (+2 -2)

@@ -8,7 +8,7 @@
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.sequential.helpers import trace_subgraphs
-from llmcompressor.utils.helpers import calibration_forward_context
+from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

 if TYPE_CHECKING:
     from llmcompressor.modifiers import Modifier
@@ -50,7 +50,7 @@ def run_pipeline(
     sample_input = next(iter(dataloader))
     subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)

-    with calibration_forward_context(model):
+    with calibration_forward_context(model), DisableQuantization(model):
         # prepare intermediates cache
         model_device = get_execution_device(model)
         intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)

Diff for: src/llmcompressor/utils/helpers.py (+1 -3)

@@ -1013,7 +1013,7 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb):
 @contextlib.contextmanager
 def DisableQuantization(module: torch.nn.Module):
     """
-    Disable quantization from QuantizationModifier
+    Disable quantization during forward passes after applying a quantization config
     """
     try:
         module.apply(disable_quantization)
@@ -1040,13 +1040,11 @@ def calibration_forward_context(model: PreTrainedModel):

     - Remove gradient calculations
     - Disable the KV cache
-    - Disable quantization during forward pass
     - Disable train mode and enable eval mode
     """
     with (
         torch.no_grad(),
         DisableKVCache(model),
-        DisableQuantization(model),
         eval_context(model),
     ):
         yield
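Only the opening of `DisableQuantization` appears in the hunk above. Filling in the unchanged lines, its overall shape is roughly the following sketch; the `finally` clause re-enabling quantization and the import path for `disable_quantization`/`enable_quantization` are assumptions, not shown in this commit.

```python
import contextlib

import torch
# import path assumed; these helpers are applied via module.apply() in helpers.py
from compressed_tensors.quantization import disable_quantization, enable_quantization


@contextlib.contextmanager
def DisableQuantization(module: torch.nn.Module):
    """
    Disable quantization during forward passes after applying a quantization config
    """
    try:
        module.apply(disable_quantization)  # shown in the hunk above
        yield
    finally:
        module.apply(enable_quantization)  # assumed: restore quantization on exit
```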

Diff for: tests/llmcompressor/utils/test_helpers.py (+3 -2)

@@ -134,11 +134,12 @@ def test_calibration_forward_context():
     model = torch.nn.Linear(1, 1)
     model.config = SimpleNamespace()
     model.config.use_cache = True
+    model.train()

     with calibration_forward_context(model):
         assert not torch.is_grad_enabled()
-        assert not model.quantization_enabled
         assert not model.config.use_cache
+        assert not model.training
     assert torch.is_grad_enabled()
-    assert model.quantization_enabled
     assert model.config.use_cache
+    assert model.training
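Since `calibration_forward_context` no longer toggles quantization, the test now seeds the model in train mode and checks the eval-mode switch instead. A standalone check of `DisableQuantization` might look like the sketch below; it is hypothetical, not part of this commit, and assumes `disable_quantization`/`enable_quantization` toggle the `quantization_enabled` flag seen in the removed assertions.

```python
import torch

from llmcompressor.utils.helpers import DisableQuantization


def test_disable_quantization():
    # hypothetical companion test, not part of this commit
    model = torch.nn.Linear(1, 1)

    with DisableQuantization(model):
        # assumes disable_quantization sets this flag, as the removed
        # assertions in test_calibration_forward_context implied
        assert not model.quantization_enabled
    assert model.quantization_enabled
```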
