|
52 | 52 | mv_module_from_gpu, |
53 | 53 | unsupport_meta_device, clear_memory, |
54 | 54 | compile_func, |
55 | | - find_matching_blocks, is_debug_mode |
| 55 | + find_matching_blocks, is_debug_mode, |
| 56 | + TORCH_VERSION_AT_LEAST_2_6 |
56 | 57 | ) |
57 | 58 | from .low_cpu_mem.utils import get_layers_before_block |
58 | 59 |
|
@@ -159,7 +160,7 @@ def __init__( |
159 | 160 | act_dynamic: bool = True, |
160 | 161 | to_quant_block_names: Union[str, list] = None, |
161 | 162 | enable_norm_bias_tuning: bool = False, |
162 | | - enable_torch_compile: bool = None, |
| 163 | + enable_torch_compile: bool = False, |
163 | 164 | device_map: Union[str, dict] = None, |
164 | 165 | **kwargs, |
165 | 166 | ): |
@@ -232,19 +233,24 @@ def __init__( |
232 | 233 | logger.info(f"using {self.model.dtype} for quantization tuning") |
233 | 234 |
|
234 | 235 | self.enable_torch_compile = enable_torch_compile |
235 | | - if self.act_bits <= 8 and self.enable_torch_compile != False: |
| 236 | + if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \ |
| 237 | + and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type: |
| 238 | + logger.info("'enable_torch_compile' is set to `False` by default. " \ |
| 239 | + "Enabling it can reduce tuning cost by 20%, but it might throw an exception.") |
| 240 | + |
| 241 | + if self.act_bits <= 8 and self.enable_torch_compile: |
236 | 242 | self.enable_torch_compile = False |
237 | 243 | logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled") |
238 | 244 |
|
239 | | - if self.low_cpu_mem_usage == True and self.enable_torch_compile != False: |
| 245 | + if self.low_cpu_mem_usage == True and self.enable_torch_compile: |
240 | 246 | self.enable_torch_compile = False |
241 | 247 | logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled") |
242 | 248 |
|
243 | | - if is_debug_mode() and self.enable_torch_compile != False: |
| 249 | + if is_debug_mode() and self.enable_torch_compile: |
244 | 250 | self.enable_torch_compile = False |
245 | 251 | logger.warning("reset enable_torch_compile to `False` as debug mode is enabled") |
246 | 252 |
|
247 | | - if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile != False: |
| 253 | + if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile: |
248 | 254 | self.enable_torch_compile = False |
249 | 255 | logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") |
250 | 256 |
|
@@ -493,13 +499,8 @@ def quant_layers(self, layer_names, layer_inputs): |
493 | 499 | self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) |
494 | 500 | clear_memory() |
495 | 501 | device = next(self.model.parameters()).device |
496 | | - if self.enable_torch_compile != False: |
497 | | - try: |
498 | | - quant_layer = compile_func(self.quant_layer, self.device, self.enable_torch_compile) |
499 | | - except: |
500 | | - logger.warning("torch compile failed, reset it to `False`") |
501 | | - self.enable_torch_compile = False |
502 | | - quant_layer = self.quant_layer |
| 502 | + if self.enable_torch_compile: |
| 503 | + quant_layer = compile_func(self.quant_layer, self.device) |
503 | 504 | else: |
504 | 505 | quant_layer = self.quant_layer |
505 | 506 | for layer_name in layer_names: |
@@ -1311,13 +1312,8 @@ def quant_blocks( |
1311 | 1312 | elif isinstance(input_others[key], list): |
1312 | 1313 | for i in range(len(input_others[key])): |
1313 | 1314 | to_dtype(input_others[key][i], tmp_dtype) |
1314 | | - if self.enable_torch_compile != False: |
1315 | | - try: |
1316 | | - quant_block = compile_func(self.quant_block, device, self.enable_torch_compile) |
1317 | | - except: |
1318 | | - logger.warning("torch compile failed, reset it to `False`") |
1319 | | - self.enable_torch_compile = False |
1320 | | - quant_block = self.quant_block |
| 1315 | + if self.enable_torch_compile: |
| 1316 | + quant_block = compile_func(self.quant_block, device) |
1321 | 1317 | else: |
1322 | 1318 | quant_block = self.quant_block |
1323 | 1319 |
|
@@ -1648,7 +1644,7 @@ def __init__( |
1648 | 1644 | act_dynamic: bool = True, |
1649 | 1645 | to_quant_block_names: Union[str, list] = None, |
1650 | 1646 | enable_norm_bias_tuning: bool = False, |
1651 | | - enable_torch_compile: bool = None, |
| 1647 | + enable_torch_compile: bool = False, |
1652 | 1648 | device_map: Union[str, dict] = None, |
1653 | 1649 | optimizer="AdamW", |
1654 | 1650 | **kwargs, |
@@ -1822,7 +1818,7 @@ def __init__( |
1822 | 1818 | act_dynamic: bool = True, |
1823 | 1819 | to_quant_block_names: Union[str, list] = None, |
1824 | 1820 | enable_norm_bias_tuning: bool = False, |
1825 | | - enable_torch_compile: bool = None, |
| 1821 | + enable_torch_compile: bool = False, |
1826 | 1822 | device_map: Union[str, dict] = None, |
1827 | 1823 | optimizer="AdamW", |
1828 | 1824 | **kwargs, |
@@ -1868,3 +1864,4 @@ def __init__( |
1868 | 1864 | optimizer=optimizer, |
1869 | 1865 | **kwargs, |
1870 | 1866 | ) |
| 1867 | + |
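
For reference, a minimal sketch of the gating behaviour this diff settles on: `enable_torch_compile` now defaults to `False` and is forced off when activation quantization (<= 8 bits), `low_cpu_mem_usage`, debug mode, or an fp8 data type is in play. The helper `resolve_torch_compile` and its flattened parameters below are hypothetical stand-ins, shown only to make the decision logic explicit; they are not part of the AutoRound codebase.

```python
# Illustrative only: `resolve_torch_compile` and its simplified parameters are
# hypothetical, not part of the AutoRound codebase.
def resolve_torch_compile(
    enable_torch_compile: bool = False,  # new default introduced by this change
    act_bits: int = 16,
    low_cpu_mem_usage: bool = False,
    debug_mode: bool = False,
    data_type: str = "int",
    act_data_type: str = "int",
) -> bool:
    """Return the effective value of enable_torch_compile after the guards."""
    if enable_torch_compile and act_bits <= 8:
        return False  # activation quantization is enabled
    if enable_torch_compile and low_cpu_mem_usage:
        return False  # low_cpu_mem_usage is enabled
    if enable_torch_compile and debug_mode:
        return False  # debug mode is enabled
    if enable_torch_compile and ("fp8" in data_type or "fp8" in act_data_type):
        return False  # an fp8 weight or activation data type is enabled
    return enable_torch_compile


print(resolve_torch_compile())                                       # False
print(resolve_torch_compile(enable_torch_compile=True))              # True
print(resolve_torch_compile(enable_torch_compile=True, act_bits=4))  # False
```

Because the flag is now a plain boolean, the later `if self.enable_torch_compile:` checks replace the old `!= False` comparisons, and the try/except fallback around `compile_func` in `quant_layers` and `quant_blocks` is no longer needed.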