|
14 | 14 |
|
15 | 15 | import contextlib
|
16 | 16 | import gc
|
| 17 | +import importlib |
17 | 18 | import inspect
|
18 | 19 | import json
|
19 | 20 | import logging
|
|
43 | 44 | from .memory import clear_device_cache, get_xpu_available_memory
|
44 | 45 | from .offload import load_offloaded_weight, offload_weight, save_offload_index
|
45 | 46 | from .tqdm import is_tqdm_available, tqdm
|
46 |
| -from .versions import is_torch_version |
| 47 | +from .versions import compare_versions, is_torch_version |
47 | 48 |
|
48 | 49 |
|
49 | 50 | if is_npu_available(check_device=False):
|
@@ -350,17 +351,19 @@ def set_module_tensor_to_device(
|
350 | 351 | elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
|
351 | 352 | new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device)
|
352 | 353 | elif param_cls.__name__ in ["AffineQuantizedTensor"]:
|
353 |
| - new_value = torch.nn.Parameter( |
354 |
| - param_cls( |
355 |
| - new_value.layout_tensor, |
356 |
| - new_value.block_size, |
357 |
| - new_value.shape, |
358 |
| - new_value.quant_min, |
359 |
| - new_value.quant_max, |
360 |
| - new_value.zero_point_domain, |
361 |
| - ), |
362 |
| - requires_grad=old_value.requires_grad, |
363 |
| - ).to(device) |
| 354 | + if importlib.util.find_spec("torchao") is not None and compare_versions("torchao", ">=", "0.7.0"): |
| 355 | + # TorchAO v0.7.0 made layout_tensor an internal private variable and exposed tensor_impl |
| 356 | + args = (new_value.tensor_impl,) |
| 357 | + else: |
| 358 | + args = (new_value.layout_tensor,) |
| 359 | + args += ( |
| 360 | + new_value.block_size, |
| 361 | + new_value.shape, |
| 362 | + new_value.quant_min, |
| 363 | + new_value.quant_max, |
| 364 | + new_value.zero_point_domain, |
| 365 | + ) |
| 366 | + new_value = torch.nn.Parameter(param_cls(*args), requires_grad=old_value.requires_grad).to(device) |
364 | 367 | else:
|
365 | 368 | new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
|
366 | 369 |
|
|
0 commit comments