Skip to content

Commit bc09b0d

Browse files
authored
Restore compatibility with PT 1.9.1 (openvinotoolkit#1806)
### Changes If running on Torch < 1.12, do not use `register_load_state_dict_post_hook` and do level range updates on each forward instead. ### Reason for changes E2E tests for torch 1.9.1 failing, because `torch.nn.Module.register_load_state_dict_post_hook` has only been added in 1.12. ### Related tickets N/A ### Tests E2E PT 1.9.1 run TBA
1 parent b07b652 commit bc09b0d

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

nncf/torch/quantization/layers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import torch
17+
from pkg_resources import parse_version
1718
from torch import distributed
1819
from torch import nn
1920

@@ -322,6 +323,7 @@ def close(self):
322323
self.hook.remove()
323324

324325
self.load_listener = LoadStateListener(self)
326+
self._old_level_range_setting = False
325327

326328
def enable_gradients(self):
327329
raise NotImplementedError
@@ -347,6 +349,8 @@ def forward(self, x):
347349
# TODO: refactor to get rid of extra if's and calls on each forward
348350
if not self.is_enabled_quantization():
349351
return x
352+
if self._old_level_range_setting:
353+
self.set_level_ranges()
350354
is_exporting = is_tracing_state()
351355
if is_exporting:
352356
with no_nncf_trace():
@@ -614,8 +618,11 @@ def __init__(self, qspec: PTQuantizerSpec):
614618
)
615619
)
616620

617-
# Values of level_low, level_high must be recalculated for load new signed parameter.
618-
self.register_load_state_dict_post_hook(lambda module, _: module.set_level_ranges())
621+
if parse_version(torch.__version__) >= parse_version("1.12"):
622+
# Values of level_low, level_high must be recalculated for load new signed parameter.
623+
self.register_load_state_dict_post_hook(lambda module, _: module.set_level_ranges())
624+
else:
625+
self._old_level_range_setting = True
619626

620627
@property
621628
def scale(self):

0 commit comments

Comments
 (0)