-from typing import Optional, Dict, Any, Tuple
-from transformers import PreTrainedModel
-from ..quant.awq import AWQQuantizer
-from ..quant.gptq import GPTQQuantizer
-from ..quant.gguf import GGUFQuantizer
-from ..trainer.logger import TrainingLogger
+from typing import Optional, Dict, Any, Union, Tuple
+import torch
+from transformers import PreTrainedModel, AutoTokenizer
+from ..quant.gguf import GGUFQuantizer, SUPPORTED_GGUF_BITS, SUPPORTED_GGUF_TYPES
+from ..utils.logger import logger
+from ..utils.memory_tracker import memory_tracker
+from ..utils.benchmark import QuantizationBenchmark
 
 class QuantLLM:
-    """High-level API for quantizing models using various methods."""
+    """High-level API for GGUF model quantization."""
     @staticmethod
     def quantize_from_pretrained(
-        model_name: str,
-        method: str,
-        quant_config_dict: Optional[Dict[str, Any]] = None,
-        calibration_data: Optional[Any] = None,  # Typically torch.Tensor or similar
-        calibration_steps: Optional[int] = 100,  # Specific to AWQ's quantize method
-        device: Optional[str] = None  # Explicit device control
-    ) -> Tuple[PreTrainedModel, Any]:  # Returns (quantized_model, tokenizer)
+        model_name_or_path: Union[str, PreTrainedModel],
+        bits: int = 4,
+        group_size: int = 128,
+        quant_type: Optional[str] = None,
+        use_packed: bool = True,
+        cpu_offload: bool = False,
+        desc_act: bool = False,
+        desc_ten: bool = False,
+        legacy_format: bool = False,
+        batch_size: int = 4,
+        device: Optional[str] = None,
+        calibration_data: Optional[torch.Tensor] = None,
+        benchmark: bool = True,
+        benchmark_input_shape: Optional[Tuple[int, ...]] = None,
+        benchmark_steps: int = 100
+    ) -> Tuple[PreTrainedModel, Any]:
         """
-        Loads a model from Hugging Face, quantizes it using the specified method,
-        and returns the quantized model and its tokenizer.
-
-        Args:
-            model_name_or_path (str): Hugging Face model ID or local path.
-            method (str): Quantization method to use ('awq', 'gptq', 'gguf').
-            quant_config_dict (Optional[Dict[str, Any]]): Dictionary with quantization parameters.
-                Common keys: 'bits', 'group_size', 'batch_size' (for quantizer init).
-                AWQ specific: 'zero_point', 'awq_version' (maps to 'version' in AWQQuantizer).
-                GPTQ specific: 'actorder', 'percdamp', 'sym'.
-                GGUF specific: 'use_packed', 'cpu_offload', 'desc_act', 'desc_ten', 'legacy_format'.
-            calibration_data (Optional[Any]): Calibration data required for quantization.
-            calibration_steps (Optional[int]): Number of calibration steps, primarily for AWQ's
-                quantize() method. Defaults to 100.
-            device (Optional[str]): Device to run quantization on ('cpu', 'cuda', 'cuda:x').
-                If None, default device selection logic in BaseQuantizer is used.
-
-        Returns:
-            Tuple[PreTrainedModel, Any]: The quantized model and its associated tokenizer.
-
-        Raises:
-            ValueError: If an unsupported quantization method is specified or essential parameters are missing.
-            RuntimeError: If quantization fails for some reason.
+        Quantize a model using GGUF format with optional benchmarking.
+        Returns (quantized_model, benchmark_results)
         """
-        logger = TrainingLogger()
-        if quant_config_dict is None:
-            quant_config_dict = {}
-
-        method_lower = method.lower()
-        logger.log_info(f"Attempting to quantize model '{model_name}' using method: {method_lower}")
-
-        bits = quant_config_dict.get('bits', 4)
-        group_size = quant_config_dict.get('group_size', 128)
-        quantizer_batch_size = quant_config_dict.get('batch_size', 4)
-
-        quantizer = None
-
-        if method_lower == 'awq':
-            awq_zero_point = quant_config_dict.get('zero_point', True)
-            awq_version = quant_config_dict.get('awq_version', 'v2')
-
-            quantizer = AWQQuantizer(
-                model_name=model_name,
-                bits=bits,
-                group_size=group_size,
-                zero_point=awq_zero_point,
-                version=awq_version,
-                batch_size=quantizer_batch_size,
-                device=device
-            )
-            logger.log_info(f"Quantizing with AWQ... Bits: {bits}, Group Size: {group_size}, Zero Point: {awq_zero_point}, Version: {awq_version}")
-            quantizer.quantize(  # Call quantize, model is updated in place
-                calibration_data=calibration_data,
-                calibration_steps=calibration_steps
-            )
-
-        elif method_lower == 'gptq':
-            gptq_actorder = quant_config_dict.get('actorder', True)
-            gptq_percdamp = quant_config_dict.get('percdamp', 0.01)
-            gptq_sym = quant_config_dict.get('sym', True)
-
-            quantizer = GPTQQuantizer(
-                model_name=model_name,
-                bits=bits,
-                group_size=group_size,
-                actorder=gptq_actorder,
-                percdamp=gptq_percdamp,
-                sym=gptq_sym,
-                batch_size=quantizer_batch_size,
-                device=device
-            )
-            logger.log_info(f"Quantizing with GPTQ... Bits: {bits}, Group Size: {group_size}, ActOrder: {gptq_actorder}, Sym: {gptq_sym}")
-            quantizer.quantize(calibration_data=calibration_data)  # Model updated in place
-
-        elif method_lower == 'gguf':
-            gguf_use_packed = quant_config_dict.get('use_packed', True)
-            gguf_cpu_offload = quant_config_dict.get('cpu_offload', False)
-            gguf_desc_act = quant_config_dict.get('desc_act', False)
-            gguf_desc_ten = quant_config_dict.get('desc_ten', False)
-            gguf_legacy_format = quant_config_dict.get('legacy_format', False)
-
+        try:
+            logger.log_info(f"Starting GGUF quantization with {bits} bits")
+            memory_tracker.log_memory("quantization_start")
+            if bits not in SUPPORTED_GGUF_BITS:
+                raise ValueError(f"Unsupported bits: {bits}. Supported values: {SUPPORTED_GGUF_BITS}")
+            if quant_type and quant_type not in SUPPORTED_GGUF_TYPES.get(bits, []):
+                raise ValueError(f"Unsupported quant_type: {quant_type} for {bits} bits")
             quantizer = GGUFQuantizer(
-                model_name=model_name,
+                model_name=model_name_or_path,
                 bits=bits,
                 group_size=group_size,
-                use_packed=gguf_use_packed,
-                cpu_offload=gguf_cpu_offload,
-                desc_act=gguf_desc_act,
-                desc_ten=gguf_desc_ten,
-                legacy_format=gguf_legacy_format,
-                batch_size=quantizer_batch_size,
+                quant_type=quant_type,
+                use_packed=use_packed,
+                cpu_offload=cpu_offload,
+                desc_act=desc_act,
+                desc_ten=desc_ten,
+                legacy_format=legacy_format,
+                batch_size=batch_size,
                 device=device
             )
-            logger.log_info(f"Quantizing with GGUF... Bits: {bits}, Group Size: {group_size}, Packed: {gguf_use_packed}, CPU Offload: {gguf_cpu_offload}")
-            quantizer.quantize(calibration_data=calibration_data)  # Model updated in place
+            logger.log_info("Starting quantization process")
+            quantized_model = quantizer.quantize(calibration_data)
+            memory_tracker.log_memory("quantization_complete")
+            benchmark_results = {}
+            if benchmark:
+                logger.log_info("Running benchmarks")
+                if not benchmark_input_shape:
+                    if hasattr(quantized_model.config, 'max_position_embeddings'):
+                        seq_len = min(32, quantized_model.config.max_position_embeddings)
+                    else:
+                        seq_len = 32
+                    benchmark_input_shape = (1, seq_len)
+                benchmarker = QuantizationBenchmark(
+                    model=quantized_model,
+                    calibration_data=calibration_data,
+                    input_shape=benchmark_input_shape,
+                    num_inference_steps=benchmark_steps,
+                    device=device
+                )
+                benchmark_results = benchmarker.run_all_benchmarks()
+                memory_tracker.log_memory("benchmarking_complete")
+                logger.log_info("Benchmark Results:")
+                if hasattr(benchmark_results, 'to_dict'):
+                    benchmark_results = benchmark_results.to_dict()
+                for metric, value in (benchmark_results.items() if isinstance(benchmark_results, dict) else []):
+                    logger.log_info(f"{metric}: {value}")
+            return quantized_model, benchmark_results
+        except Exception as e:
+            logger.log_error(f"Quantization failed: {str(e)}")
+            raise
+        finally:
+            memory_tracker.clear_memory()
 
-        else:
-            logger.log_error(f"Unsupported quantization method: {method}")
-            raise ValueError(f"Unsupported quantization method: {method}. Supported methods are 'awq', 'gptq', 'gguf'.")
 
-        if quantizer is None or quantizer.model is None:
-            logger.log_error(f"Failed to initialize quantizer or obtain quantized model for method: {method}")
-            raise RuntimeError(f"Quantization failed for method: {method}. Quantizer or model is None.")
 
-        logger.log_info(f"Successfully quantized model with method: {method_lower}")
-        return quantizer.model, quantizer.tokenizer
+    @staticmethod
+    def save_quantized_model(
+        model: PreTrainedModel,
+        output_path: str,
+        save_tokenizer: bool = True
+    ):
+        """
+        Save a quantized model and optionally its tokenizer.
+
+        Args:
+            model: Quantized model to save
+            output_path: Path to save the model
+            save_tokenizer: Whether to save the tokenizer
+        """
+        try:
+            logger.log_info(f"Saving quantized model to {output_path}")
+            memory_tracker.log_memory("save_start")
+
+            # Save model
+            model.save_pretrained(output_path)
+
+            # Save tokenizer if requested
+            if save_tokenizer and hasattr(model, 'config'):
+                if hasattr(model.config, '_name_or_path'):
+                    try:
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            model.config._name_or_path,
+                            trust_remote_code=True
+                        )
+                        tokenizer.save_pretrained(output_path)
+                        logger.log_info("Tokenizer saved successfully")
+                    except Exception as e:
+                        logger.log_warning(f"Failed to save tokenizer: {e}")
+
+            memory_tracker.log_memory("save_complete")
+            logger.log_info("Model saved successfully")
+
+        except Exception as e:
+            logger.log_error(f"Failed to save model: {str(e)}")
+            raise
+        finally:
+            memory_tracker.clear_memory()
+
+    @staticmethod
+    def convert_to_gguf(
+        model: PreTrainedModel,
+        output_path: str,
+        quant_config: Optional[Dict[str, Any]] = None
+    ):
+        """
+        Convert a quantized model to GGUF format.
+
+        Args:
+            model: Model to convert
+            output_path: Path to save GGUF file
+            quant_config: Optional quantization configuration
+        """
+        try:
+            logger.log_info(f"Converting model to GGUF format: {output_path}")
+            memory_tracker.log_memory("conversion_start")
+
+            # Get quantization config from model if not provided
+            if not quant_config and hasattr(model.config, 'quantization_config'):
+                quant_config = model.config.quantization_config
+
+            # Create quantizer with existing or default config
+            quantizer = GGUFQuantizer(
+                model_name=model,
+                bits=quant_config.get('bits', 4) if quant_config else 4,
+                group_size=quant_config.get('group_size', 128) if quant_config else 128,
+                quant_type=quant_config.get('quant_type', None) if quant_config else None,
+                use_packed=quant_config.get('use_packed', True) if quant_config else True
+            )
+
+            # Convert to GGUF
+            quantizer.convert_to_gguf(output_path)
+            memory_tracker.log_memory("conversion_complete")
+            logger.log_info("GGUF conversion completed successfully")
+
+        except Exception as e:
+            logger.log_error(f"GGUF conversion failed: {str(e)}")
+            raise
+        finally:
+            memory_tracker.clear_memory()
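
For reference, a minimal usage sketch of the GGUF-only API introduced above. The import path and the model ID are illustrative assumptions; the method names and parameters are taken from the diff.

# Usage sketch (not part of the diff). Assumes the QuantLLM class above is
# importable; the `quantllm` import path and the model ID are assumptions.
import torch
from quantllm import QuantLLM  # assumed import path

# Small batch of random token IDs used as calibration data (shape is illustrative).
calibration_data = torch.randint(0, 32000, (4, 32))

# 4-bit GGUF quantization with the optional benchmarking pass enabled.
model, benchmark_results = QuantLLM.quantize_from_pretrained(
    model_name_or_path="facebook/opt-125m",
    bits=4,
    group_size=128,
    calibration_data=calibration_data,
    benchmark=True,
    benchmark_steps=10,
)

# Persist the quantized weights (and tokenizer), then export a GGUF file.
QuantLLM.save_quantized_model(model, "opt-125m-quantized")
QuantLLM.convert_to_gguf(model, "opt-125m-q4.gguf")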