
Commit 2a97a9a

Merge pull request #12 from codewithdark-git/feature/add_GGUF
Add the GGUF for Quantization
2 parents 4e22cbe + b91fd7b commit 2a97a9a

16 files changed: +353 −1,391 lines

quantllm/api/high_level.py

Lines changed: 157 additions & 111 deletions
@@ -1,125 +1,171 @@
-from typing import Optional, Dict, Any, Tuple
-from transformers import PreTrainedModel
-from ..quant.awq import AWQQuantizer
-from ..quant.gptq import GPTQQuantizer
-from ..quant.gguf import GGUFQuantizer
-from ..trainer.logger import TrainingLogger
+from typing import Optional, Dict, Any, Union, Tuple
+import torch
+from transformers import PreTrainedModel, AutoTokenizer
+from ..quant.gguf import GGUFQuantizer, SUPPORTED_GGUF_BITS, SUPPORTED_GGUF_TYPES
+from ..utils.logger import logger
+from ..utils.memory_tracker import memory_tracker
+from ..utils.benchmark import QuantizationBenchmark

 class QuantLLM:
-    """High-level API for quantizing models using various methods."""
+    """High-level API for GGUF model quantization."""
     @staticmethod
     def quantize_from_pretrained(
-        model_name: str,
-        method: str,
-        quant_config_dict: Optional[Dict[str, Any]] = None,
-        calibration_data: Optional[Any] = None, # Typically torch.Tensor or similar
-        calibration_steps: Optional[int] = 100, # Specific to AWQ's quantize method
-        device: Optional[str] = None # Explicit device control
-    ) -> Tuple[PreTrainedModel, Any]: # Returns (quantized_model, tokenizer)
+        model_name_or_path: Union[str, PreTrainedModel],
+        bits: int = 4,
+        group_size: int = 128,
+        quant_type: Optional[str] = None,
+        use_packed: bool = True,
+        cpu_offload: bool = False,
+        desc_act: bool = False,
+        desc_ten: bool = False,
+        legacy_format: bool = False,
+        batch_size: int = 4,
+        device: Optional[str] = None,
+        calibration_data: Optional[torch.Tensor] = None,
+        benchmark: bool = True,
+        benchmark_input_shape: Optional[Tuple[int, ...]] = None,
+        benchmark_steps: int = 100
+    ) -> Tuple[PreTrainedModel, Any]:
         """
-        Loads a model from Hugging Face, quantizes it using the specified method,
-        and returns the quantized model and its tokenizer.
-
-        Args:
-            model_name_or_path (str): Hugging Face model ID or local path.
-            method (str): Quantization method to use ('awq', 'gptq', 'gguf').
-            quant_config_dict (Optional[Dict[str, Any]]): Dictionary with quantization parameters.
-                Common keys: 'bits', 'group_size', 'batch_size' (for quantizer init).
-                AWQ specific: 'zero_point', 'awq_version' (maps to 'version' in AWQQuantizer).
-                GPTQ specific: 'actorder', 'percdamp', 'sym'.
-                GGUF specific: 'use_packed', 'cpu_offload', 'desc_act', 'desc_ten', 'legacy_format'.
-            calibration_data (Optional[Any]): Calibration data required for quantization.
-            calibration_steps (Optional[int]): Number of calibration steps, primarily for AWQ's
-                quantize() method. Defaults to 100.
-            device (Optional[str]): Device to run quantization on ('cpu', 'cuda', 'cuda:x').
-                If None, default device selection logic in BaseQuantizer is used.
-
-        Returns:
-            Tuple[PreTrainedModel, Any]: The quantized model and its associated tokenizer.
-
-        Raises:
-            ValueError: If an unsupported quantization method is specified or essential parameters are missing.
-            RuntimeError: If quantization fails for some reason.
+        Quantize a model using GGUF format with optional benchmarking.
+        Returns (quantized_model, benchmark_results)
         """
-        logger = TrainingLogger()
-        if quant_config_dict is None:
-            quant_config_dict = {}
-
-        method_lower = method.lower()
-        logger.log_info(f"Attempting to quantize model '{model_name}' using method: {method_lower}")
-
-        bits = quant_config_dict.get('bits', 4)
-        group_size = quant_config_dict.get('group_size', 128)
-        quantizer_batch_size = quant_config_dict.get('batch_size', 4)
-
-        quantizer = None
-
-        if method_lower == 'awq':
-            awq_zero_point = quant_config_dict.get('zero_point', True)
-            awq_version = quant_config_dict.get('awq_version', 'v2')
-
-            quantizer = AWQQuantizer(
-                model_name=model_name,
-                bits=bits,
-                group_size=group_size,
-                zero_point=awq_zero_point,
-                version=awq_version,
-                batch_size=quantizer_batch_size,
-                device=device
-            )
-            logger.log_info(f"Quantizing with AWQ... Bits: {bits}, Group Size: {group_size}, Zero Point: {awq_zero_point}, Version: {awq_version}")
-            quantizer.quantize( # Call quantize, model is updated in place
-                calibration_data=calibration_data,
-                calibration_steps=calibration_steps
-            )
-
-        elif method_lower == 'gptq':
-            gptq_actorder = quant_config_dict.get('actorder', True)
-            gptq_percdamp = quant_config_dict.get('percdamp', 0.01)
-            gptq_sym = quant_config_dict.get('sym', True)
-
-            quantizer = GPTQQuantizer(
-                model_name=model_name,
-                bits=bits,
-                group_size=group_size,
-                actorder=gptq_actorder,
-                percdamp=gptq_percdamp,
-                sym=gptq_sym,
-                batch_size=quantizer_batch_size,
-                device=device
-            )
-            logger.log_info(f"Quantizing with GPTQ... Bits: {bits}, Group Size: {group_size}, ActOrder: {gptq_actorder}, Sym: {gptq_sym}")
-            quantizer.quantize(calibration_data=calibration_data) # Model updated in place
-
-        elif method_lower == 'gguf':
-            gguf_use_packed = quant_config_dict.get('use_packed', True)
-            gguf_cpu_offload = quant_config_dict.get('cpu_offload', False)
-            gguf_desc_act = quant_config_dict.get('desc_act', False)
-            gguf_desc_ten = quant_config_dict.get('desc_ten', False)
-            gguf_legacy_format = quant_config_dict.get('legacy_format', False)
-
+        try:
+            logger.log_info(f"Starting GGUF quantization with {bits} bits")
+            memory_tracker.log_memory("quantization_start")
+            if bits not in SUPPORTED_GGUF_BITS:
+                raise ValueError(f"Unsupported bits: {bits}. Supported values: {SUPPORTED_GGUF_BITS}")
+            if quant_type and quant_type not in SUPPORTED_GGUF_TYPES.get(bits, []):
+                raise ValueError(f"Unsupported quant_type: {quant_type} for {bits} bits")
             quantizer = GGUFQuantizer(
-                model_name=model_name,
+                model_name=model_name_or_path,
                 bits=bits,
                 group_size=group_size,
-                use_packed=gguf_use_packed,
-                cpu_offload=gguf_cpu_offload,
-                desc_act=gguf_desc_act,
-                desc_ten=gguf_desc_ten,
-                legacy_format=gguf_legacy_format,
-                batch_size=quantizer_batch_size,
+                quant_type=quant_type,
+                use_packed=use_packed,
+                cpu_offload=cpu_offload,
+                desc_act=desc_act,
+                desc_ten=desc_ten,
+                legacy_format=legacy_format,
+                batch_size=batch_size,
                 device=device
             )
-            logger.log_info(f"Quantizing with GGUF... Bits: {bits}, Group Size: {group_size}, Packed: {gguf_use_packed}, CPU Offload: {gguf_cpu_offload}")
-            quantizer.quantize(calibration_data=calibration_data) # Model updated in place
+            logger.log_info("Starting quantization process")
+            quantized_model = quantizer.quantize(calibration_data)
+            memory_tracker.log_memory("quantization_complete")
+            benchmark_results = {}
+            if benchmark:
+                logger.log_info("Running benchmarks")
+                if not benchmark_input_shape:
+                    if hasattr(quantized_model.config, 'max_position_embeddings'):
+                        seq_len = min(32, quantized_model.config.max_position_embeddings)
+                    else:
+                        seq_len = 32
+                    benchmark_input_shape = (1, seq_len)
+                benchmarker = QuantizationBenchmark(
+                    model=quantized_model,
+                    calibration_data=calibration_data,
+                    input_shape=benchmark_input_shape,
+                    num_inference_steps=benchmark_steps,
+                    device=device
+                )
+                benchmark_results = benchmarker.run_all_benchmarks()
+                memory_tracker.log_memory("benchmarking_complete")
+                logger.log_info("Benchmark Results:")
+                if hasattr(benchmark_results, 'to_dict'):
+                    benchmark_results = benchmark_results.to_dict()
+                for metric, value in (benchmark_results.items() if isinstance(benchmark_results, dict) else []):
+                    logger.log_info(f"{metric}: {value}")
+            return quantized_model, benchmark_results
+        except Exception as e:
+            logger.log_error(f"Quantization failed: {str(e)}")
+            raise
+        finally:
+            memory_tracker.clear_memory()

-        else:
-            logger.log_error(f"Unsupported quantization method: {method}")
-            raise ValueError(f"Unsupported quantization method: {method}. Supported methods are 'awq', 'gptq', 'gguf'.")

-        if quantizer is None or quantizer.model is None:
-            logger.log_error(f"Failed to initialize quantizer or obtain quantized model for method: {method}")
-            raise RuntimeError(f"Quantization failed for method: {method}. Quantizer or model is None.")

-        logger.log_info(f"Successfully quantized model with method: {method_lower}")
-        return quantizer.model, quantizer.tokenizer
+    @staticmethod
+    def save_quantized_model(
+        model: PreTrainedModel,
+        output_path: str,
+        save_tokenizer: bool = True
+    ):
+        """
+        Save a quantized model and optionally its tokenizer.
+
+        Args:
+            model: Quantized model to save
+            output_path: Path to save the model
+            save_tokenizer: Whether to save the tokenizer
+        """
+        try:
+            logger.log_info(f"Saving quantized model to {output_path}")
+            memory_tracker.log_memory("save_start")
+
+            # Save model
+            model.save_pretrained(output_path)
+
+            # Save tokenizer if requested
+            if save_tokenizer and hasattr(model, 'config'):
+                if hasattr(model.config, '_name_or_path'):
+                    try:
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            model.config._name_or_path,
+                            trust_remote_code=True
+                        )
+                        tokenizer.save_pretrained(output_path)
+                        logger.log_info("Tokenizer saved successfully")
+                    except Exception as e:
+                        logger.log_warning(f"Failed to save tokenizer: {e}")
+
+            memory_tracker.log_memory("save_complete")
+            logger.log_info("Model saved successfully")
+
+        except Exception as e:
+            logger.log_error(f"Failed to save model: {str(e)}")
+            raise
+        finally:
+            memory_tracker.clear_memory()
+
+    @staticmethod
+    def convert_to_gguf(
+        model: PreTrainedModel,
+        output_path: str,
+        quant_config: Optional[Dict[str, Any]] = None
+    ):
+        """
+        Convert a quantized model to GGUF format.
+
+        Args:
+            model: Model to convert
+            output_path: Path to save GGUF file
+            quant_config: Optional quantization configuration
+        """
+        try:
+            logger.log_info(f"Converting model to GGUF format: {output_path}")
+            memory_tracker.log_memory("conversion_start")
+
+            # Get quantization config from model if not provided
+            if not quant_config and hasattr(model.config, 'quantization_config'):
+                quant_config = model.config.quantization_config
+
+            # Create quantizer with existing or default config
+            quantizer = GGUFQuantizer(
+                model_name=model,
+                bits=quant_config.get('bits', 4) if quant_config else 4,
+                group_size=quant_config.get('group_size', 128) if quant_config else 128,
+                quant_type=quant_config.get('quant_type', None) if quant_config else None,
+                use_packed=quant_config.get('use_packed', True) if quant_config else True
+            )
+
+            # Convert to GGUF
+            quantizer.convert_to_gguf(output_path)
+            memory_tracker.log_memory("conversion_complete")
+            logger.log_info("GGUF conversion completed successfully")
+
+        except Exception as e:
+            logger.log_error(f"GGUF conversion failed: {str(e)}")
+            raise
+        finally:
+            memory_tracker.clear_memory()
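
For reference, a minimal usage sketch of the new GGUF quantization entry point introduced by this diff. It is not part of the commit; the import path is assumed from the file location quantllm/api/high_level.py, and the model ID, calibration tensor, and device choice are illustrative placeholders.

# Sketch only: import path, checkpoint name, and calibration data are assumptions.
import torch
from quantllm.api.high_level import QuantLLM  # assumed import path for this module

# Hypothetical calibration batch of token IDs; shape and vocabulary depend on the target model.
calibration_data = torch.randint(0, 32000, (8, 32))

quantized_model, benchmark_results = QuantLLM.quantize_from_pretrained(
    model_name_or_path="facebook/opt-125m",   # placeholder checkpoint
    bits=4,                                   # must be in SUPPORTED_GGUF_BITS
    group_size=128,
    quant_type=None,                          # or a value from SUPPORTED_GGUF_TYPES[4]
    calibration_data=calibration_data,
    benchmark=True,                           # also returns benchmark metrics
    benchmark_steps=10,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
print(benchmark_results)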
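The other two static methods added here round out the workflow: persisting the quantized checkpoint and exporting a .gguf file. Again a hedged sketch continuing the example above; the output paths are placeholders, and the quant_config keys mirror the defaults read inside convert_to_gguf in this diff.

# Save the quantized model (optionally re-fetching and saving its tokenizer),
# then export it to GGUF format.
QuantLLM.save_quantized_model(
    model=quantized_model,
    output_path="./opt-125m-quantized",    # placeholder directory
    save_tokenizer=True,                   # tokenizer is reloaded via model.config._name_or_path
)

QuantLLM.convert_to_gguf(
    model=quantized_model,
    output_path="./opt-125m-q4.gguf",      # placeholder file name
    quant_config={"bits": 4, "group_size": 128, "use_packed": True},
)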
