Skip to content

Commit f0b0305

Browse files
a-r-r-o-w and XuehaiPan
authored
Fix for offloading when using TorchAO >= 0.7.0 (#3332)
* fix
* update
* fix
* apply suggestions from review

Co-Authored-By: Benjamin Bossan <[email protected]>
Co-Authored-By: Xuehai Pan <[email protected]>

* make style

---------

Co-authored-by: Xuehai Pan <[email protected]>
1 parent 8097343 commit f0b0305

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

src/accelerate/utils/modeling.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import contextlib
1616
import gc
17+
import importlib
1718
import inspect
1819
import json
1920
import logging
@@ -43,7 +44,7 @@
4344
from .memory import clear_device_cache, get_xpu_available_memory
4445
from .offload import load_offloaded_weight, offload_weight, save_offload_index
4546
from .tqdm import is_tqdm_available, tqdm
46-
from .versions import is_torch_version
47+
from .versions import compare_versions, is_torch_version
4748

4849

4950
if is_npu_available(check_device=False):
@@ -350,17 +351,19 @@ def set_module_tensor_to_device(
350351
elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
351352
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device)
352353
elif param_cls.__name__ in ["AffineQuantizedTensor"]:
353-
new_value = torch.nn.Parameter(
354-
param_cls(
355-
new_value.layout_tensor,
356-
new_value.block_size,
357-
new_value.shape,
358-
new_value.quant_min,
359-
new_value.quant_max,
360-
new_value.zero_point_domain,
361-
),
362-
requires_grad=old_value.requires_grad,
363-
).to(device)
354+
if importlib.util.find_spec("torchao") is not None and compare_versions("torchao", ">=", "0.7.0"):
355+
# TorchAO v0.7.0 made layout_tensor an internal private variable and exposed tensor_impl
356+
args = (new_value.tensor_impl,)
357+
else:
358+
args = (new_value.layout_tensor,)
359+
args += (
360+
new_value.block_size,
361+
new_value.shape,
362+
new_value.quant_min,
363+
new_value.quant_max,
364+
new_value.zero_point_domain,
365+
)
366+
new_value = torch.nn.Parameter(param_cls(*args), requires_grad=old_value.requires_grad).to(device)
364367
else:
365368
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
366369

0 commit comments

Comments (0)