"""High-level API for model quantization."""

from typing import Optional, Dict, Any, Union, Tuple

import torch
from transformers import PreTrainedModel, AutoTokenizer

from ..quant.gguf import GGUFQuantizer, SUPPORTED_GGUF_BITS, SUPPORTED_GGUF_TYPES
from ..utils.logger import logger
from ..utils.memory_tracker import memory_tracker
from ..utils.benchmark import QuantizationBenchmark


class QuantizationAPI:
    """High-level API for model quantization with GGUF support."""

    @staticmethod
    def quantize_model(
        model_name_or_path: Union[str, PreTrainedModel],
        bits: int = 4,
        group_size: int = 128,
        quant_type: Optional[str] = None,
        use_packed: bool = True,
        cpu_offload: bool = False,
        desc_act: bool = False,
        desc_ten: bool = False,
        legacy_format: bool = False,
        batch_size: int = 4,
        device: Optional[str] = None,
        calibration_data: Optional[torch.Tensor] = None,
        benchmark: bool = True,
        benchmark_input_shape: Optional[Tuple[int, ...]] = None,
        benchmark_steps: int = 100
    ) -> Tuple[PreTrainedModel, Dict[str, Any]]:
        """
        Quantize a model using GGUF format with optional benchmarking.

        Args:
            model_name_or_path: Model identifier or a loaded model instance
            bits: Number of bits for quantization
            group_size: Size of quantization groups
            quant_type: GGUF quantization type
            use_packed: Whether to use packed format
            cpu_offload: Whether to offload to CPU during quantization
            desc_act: Whether to include activation descriptors
            desc_ten: Whether to include tensor descriptors
            legacy_format: Whether to use the legacy format
            batch_size: Batch size for processing
            device: Device to run quantization on
            calibration_data: Optional calibration tensor
            benchmark: Whether to run benchmarks after quantization
            benchmark_input_shape: Shape for benchmark inputs
            benchmark_steps: Number of benchmark inference steps

        Returns:
            Tuple of (quantized model, benchmark results)
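
        Example (illustrative sketch; the checkpoint below is an arbitrary
        public model id, not a default of this package):
            model, results = QuantizationAPI.quantize_model(
                "facebook/opt-125m",
                bits=4,
                group_size=128,
                benchmark=False,
            )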
        """
        try:
            logger.log_info(f"Starting model quantization with {bits} bits")
            memory_tracker.log_memory("quantization_start")

            # Validate parameters
            if bits not in SUPPORTED_GGUF_BITS:
                raise ValueError(f"Unsupported bits: {bits}. Supported values: {SUPPORTED_GGUF_BITS}")

            if quant_type and quant_type not in SUPPORTED_GGUF_TYPES.get(bits, []):
                raise ValueError(f"Unsupported quant_type: {quant_type} for {bits} bits")

            # Initialize quantizer
            quantizer = GGUFQuantizer(
                model_name=model_name_or_path,
                bits=bits,
                group_size=group_size,
                quant_type=quant_type,
                use_packed=use_packed,
                cpu_offload=cpu_offload,
                desc_act=desc_act,
                desc_ten=desc_ten,
                legacy_format=legacy_format,
                batch_size=batch_size,
                device=device
            )

            # Perform quantization
            logger.log_info("Starting quantization process")
            quantized_model = quantizer.quantize(calibration_data)
            memory_tracker.log_memory("quantization_complete")

            # Run benchmarks if requested
            benchmark_results = {}
            if benchmark:
                logger.log_info("Running benchmarks")
                if not benchmark_input_shape:
                    # Default shape based on model config
                    if hasattr(quantized_model.config, 'max_position_embeddings'):
                        seq_len = min(32, quantized_model.config.max_position_embeddings)
                    else:
                        seq_len = 32
                    benchmark_input_shape = (1, seq_len)

                benchmarker = QuantizationBenchmark(
                    model=quantized_model,
                    calibration_data=calibration_data,
                    input_shape=benchmark_input_shape,
                    num_inference_steps=benchmark_steps,
                    device=device
                )

                benchmark_results = benchmarker.run_all_benchmarks()
                memory_tracker.log_memory("benchmarking_complete")

                # Log benchmark summary
                logger.log_info("Benchmark Results:")
                for metric, value in benchmark_results.items():
                    logger.log_info(f"{metric}: {value}")

            return quantized_model, benchmark_results

        except Exception as e:
            logger.log_error(f"Quantization failed: {str(e)}")
            raise
        finally:
            memory_tracker.clear_memory()

    @staticmethod
    def save_quantized_model(
        model: PreTrainedModel,
        output_path: str,
        save_tokenizer: bool = True
    ):
        """
        Save a quantized model and optionally its tokenizer.

        Args:
            model: Quantized model to save
            output_path: Path to save the model
            save_tokenizer: Whether to save the tokenizer
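
        Example (illustrative; the output directory is a placeholder):
            QuantizationAPI.save_quantized_model(model, "./opt-125m-4bit")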
        """
        try:
            logger.log_info(f"Saving quantized model to {output_path}")
            memory_tracker.log_memory("save_start")

            # Save model
            model.save_pretrained(output_path)

            # Save tokenizer if requested
            if save_tokenizer and hasattr(model, 'config'):
                if hasattr(model.config, '_name_or_path'):
                    try:
                        tokenizer = AutoTokenizer.from_pretrained(
                            model.config._name_or_path,
                            trust_remote_code=True
                        )
                        tokenizer.save_pretrained(output_path)
                        logger.log_info("Tokenizer saved successfully")
                    except Exception as e:
                        logger.log_warning(f"Failed to save tokenizer: {e}")

            memory_tracker.log_memory("save_complete")
            logger.log_info("Model saved successfully")

        except Exception as e:
            logger.log_error(f"Failed to save model: {str(e)}")
            raise
        finally:
            memory_tracker.clear_memory()

    @staticmethod
    def convert_to_gguf(
        model: PreTrainedModel,
        output_path: str,
        quant_config: Optional[Dict[str, Any]] = None
    ):
        """
        Convert a quantized model to GGUF format.

        Args:
            model: Model to convert
            output_path: Path to save the GGUF file
            quant_config: Optional quantization configuration
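
        Example (illustrative; the file name and config values are placeholders):
            QuantizationAPI.convert_to_gguf(
                model,
                "model-q4.gguf",
                quant_config={"bits": 4, "group_size": 128, "use_packed": True},
            )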
        """
        try:
            logger.log_info(f"Converting model to GGUF format: {output_path}")
            memory_tracker.log_memory("conversion_start")

            # Get quantization config from model if not provided
            if not quant_config and hasattr(model.config, 'quantization_config'):
                quant_config = model.config.quantization_config

            # Create quantizer with existing or default config
            quantizer = GGUFQuantizer(
                model_name=model,
                bits=quant_config.get('bits', 4) if quant_config else 4,
                group_size=quant_config.get('group_size', 128) if quant_config else 128,
                quant_type=quant_config.get('quant_type', None) if quant_config else None,
                use_packed=quant_config.get('use_packed', True) if quant_config else True
            )

            # Convert to GGUF
            quantizer.convert_to_gguf(output_path)
            memory_tracker.log_memory("conversion_complete")
            logger.log_info("GGUF conversion completed successfully")

        except Exception as e:
            logger.log_error(f"GGUF conversion failed: {str(e)}")
            raise
        finally:
            memory_tracker.clear_memory()
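

# Minimal end-to-end sketch (illustrative only): the checkpoint name and output
# paths are placeholders, and the module's relative imports mean it must be run
# with `python -m` from its parent package.
if __name__ == "__main__":
    quantized, _ = QuantizationAPI.quantize_model(
        "facebook/opt-125m",
        bits=4,
        group_size=128,
        benchmark=False,
    )
    QuantizationAPI.save_quantized_model(quantized, "./opt-125m-4bit")
    QuantizationAPI.convert_to_gguf(quantized, "./opt-125m-4bit/model-q4.gguf")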