
Commit aa7496d

use calib forward context
Signed-off-by: Kyle Sayers <[email protected]>
1 parent b9bd970 commit aa7496d

File tree

  • src/llmcompressor/modifiers

2 files changed: +16 −21 lines changed

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 8 additions & 12 deletions
@@ -32,6 +32,7 @@
     run_calibration_forward,
 )
 from llmcompressor.observers.helpers import get_observer_token_count
+from llmcompressor.utils.helpers import calibration_forward_context

 __all__ = ["QuantizationModifier"]

@@ -309,18 +310,13 @@ def _calibrate(self, module: Module):
             f"{len(self.calibration_dataloader_)} samples..."
         )

-        module_training = module.training
-        module.eval()
-
-        run_calibration_forward(
-            module,
-            self.calibration_dataloader_,
-            self.num_calibration_steps,
-            self.calibration_function_,
-        )
-
-        if module_training:
-            module.train()
+        with calibration_forward_context(module):
+            run_calibration_forward(
+                module,
+                self.calibration_dataloader_,
+                self.num_calibration_steps,
+                self.calibration_function_,
+            )

     def _check_token_distribution(
         self, model: Module, threshold: Optional[float] = None

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 8 additions & 9 deletions
@@ -14,6 +14,7 @@
 )
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
+from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_matching_layer,
@@ -250,12 +251,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
                 " CompressionSession to run the SmoothQuant modifier"
             )

-        run_calibration_forward(
-            model,
-            calibration_dataloader,
-            self.num_calibration_steps,
-            self.calibration_function,
-        )
+        with calibration_forward_context(model):
+            run_calibration_forward(
+                model,
+                calibration_dataloader,
+                self.num_calibration_steps,
+                self.calibration_function,
+            )

         # remove the hooks now that we are done calibrating
         self.remove_hooks()
@@ -313,9 +315,6 @@ def smooth(module):
             smooth(layer)
             smooth(smooth_layer)

-        # clear out allocated smoothing scales
-        torch.cuda.empty_cache()
-
     def _calculate_smoothing_scales(
         self, balance_layers: List[Module], activation_scales: torch.Tensor
     ) -> List[float]:
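Both modifiers switch from manually saving the training flag, calling module.eval(), and restoring training mode afterwards to a shared calibration_forward_context helper. The helper's implementation is not part of this commit; the following is a minimal sketch, assuming (as the replaced code suggests) that it puts the model in eval mode, disables gradients, and restores the original training state on exit. It is an illustration, not the actual llmcompressor.utils.helpers code.

# Hypothetical sketch of a calibration forward context manager.
# The real calibration_forward_context lives in llmcompressor.utils.helpers
# and is not shown in this diff.
import contextlib

import torch
from torch.nn import Module


@contextlib.contextmanager
def calibration_forward_context(model: Module):
    """Run calibration forward passes in eval mode with gradients disabled,
    restoring the model's original training state on exit."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield
    finally:
        # put the model back the way the caller left it
        model.train(was_training)

Compared with the removed module_training / module.eval() bookkeeping, a context manager keeps the save-and-restore logic in one place and restores the training state even if calibration raises.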
