Commit 05a422a

Add the GGUF for Quantization
1 parent 9514de1 commit 05a422a

File tree

12 files changed: +731, -1460 lines changed

quantllm/api/quantization.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
"""High-level API for model quantization."""

from typing import Optional, Dict, Any, Union, Tuple
import torch
from transformers import PreTrainedModel, AutoTokenizer
from ..quant.gguf import GGUFQuantizer, SUPPORTED_GGUF_BITS, SUPPORTED_GGUF_TYPES
from ..utils.logger import logger
from ..utils.memory_tracker import memory_tracker
from ..utils.benchmark import QuantizationBenchmark

class QuantizationAPI:
    """High-level API for model quantization with GGUF support."""

    @staticmethod
    def quantize_model(
        model_name_or_path: Union[str, PreTrainedModel],
        bits: int = 4,
        group_size: int = 128,
        quant_type: Optional[str] = None,
        use_packed: bool = True,
        cpu_offload: bool = False,
        desc_act: bool = False,
        desc_ten: bool = False,
        legacy_format: bool = False,
        batch_size: int = 4,
        device: Optional[str] = None,
        calibration_data: Optional[torch.Tensor] = None,
        benchmark: bool = True,
        benchmark_input_shape: Optional[Tuple[int, ...]] = None,
        benchmark_steps: int = 100
    ) -> Tuple[PreTrainedModel, Dict[str, Any]]:
        """
        Quantize a model using GGUF format with optional benchmarking.

        Args:
            model_name_or_path: Model identifier or instance
            bits: Number of bits for quantization
            group_size: Size of quantization groups
            quant_type: GGUF quantization type
            use_packed: Whether to use packed format
            cpu_offload: Whether to offload to CPU during quantization
            desc_act: Whether to include activation descriptors
            desc_ten: Whether to include tensor descriptors
            legacy_format: Whether to use legacy format
            batch_size: Batch size for processing
            device: Device for quantization
            calibration_data: Optional calibration data
            benchmark: Whether to run benchmarks
            benchmark_input_shape: Shape for benchmark inputs
            benchmark_steps: Number of benchmark steps

        Returns:
            Tuple of (quantized model, benchmark results)
        """
        try:
            logger.log_info(f"Starting model quantization with {bits} bits")
            memory_tracker.log_memory("quantization_start")

            # Validate parameters
            if bits not in SUPPORTED_GGUF_BITS:
                raise ValueError(f"Unsupported bits: {bits}. Supported values: {SUPPORTED_GGUF_BITS}")

            if quant_type and quant_type not in SUPPORTED_GGUF_TYPES.get(bits, []):
                raise ValueError(f"Unsupported quant_type: {quant_type} for {bits} bits")

            # Initialize quantizer
            quantizer = GGUFQuantizer(
                model_name=model_name_or_path,
                bits=bits,
                group_size=group_size,
                quant_type=quant_type,
                use_packed=use_packed,
                cpu_offload=cpu_offload,
                desc_act=desc_act,
                desc_ten=desc_ten,
                legacy_format=legacy_format,
                batch_size=batch_size,
                device=device
            )

            # Perform quantization
            logger.log_info("Starting quantization process")
            quantized_model = quantizer.quantize(calibration_data)
            memory_tracker.log_memory("quantization_complete")

            # Run benchmarks if requested
            benchmark_results = {}
            if benchmark:
                logger.log_info("Running benchmarks")
                if not benchmark_input_shape:
                    # Default shape based on model config
                    if hasattr(quantized_model.config, 'max_position_embeddings'):
                        seq_len = min(32, quantized_model.config.max_position_embeddings)
                    else:
                        seq_len = 32
                    benchmark_input_shape = (1, seq_len)

                benchmarker = QuantizationBenchmark(
                    model=quantized_model,
                    calibration_data=calibration_data,
                    input_shape=benchmark_input_shape,
                    num_inference_steps=benchmark_steps,
                    device=device
                )

                benchmark_results = benchmarker.run_all_benchmarks()
                memory_tracker.log_memory("benchmarking_complete")

                # Log benchmark summary
                logger.log_info("Benchmark Results:")
                for metric, value in benchmark_results.items():
                    logger.log_info(f"{metric}: {value}")

            return quantized_model, benchmark_results

        except Exception as e:
            logger.log_error(f"Quantization failed: {str(e)}")
            raise
        finally:
            memory_tracker.clear_memory()

    @staticmethod
    def save_quantized_model(
        model: PreTrainedModel,
        output_path: str,
        save_tokenizer: bool = True
    ):
        """
        Save a quantized model and optionally its tokenizer.

        Args:
            model: Quantized model to save
            output_path: Path to save the model
            save_tokenizer: Whether to save the tokenizer
        """
        try:
            logger.log_info(f"Saving quantized model to {output_path}")
            memory_tracker.log_memory("save_start")

            # Save model
            model.save_pretrained(output_path)

            # Save tokenizer if requested
            if save_tokenizer and hasattr(model, 'config'):
                if hasattr(model.config, '_name_or_path'):
                    try:
                        tokenizer = AutoTokenizer.from_pretrained(
                            model.config._name_or_path,
                            trust_remote_code=True
                        )
                        tokenizer.save_pretrained(output_path)
                        logger.log_info("Tokenizer saved successfully")
                    except Exception as e:
                        logger.log_warning(f"Failed to save tokenizer: {e}")

            memory_tracker.log_memory("save_complete")
            logger.log_info("Model saved successfully")

        except Exception as e:
            logger.log_error(f"Failed to save model: {str(e)}")
            raise
        finally:
            memory_tracker.clear_memory()

    @staticmethod
    def convert_to_gguf(
        model: PreTrainedModel,
        output_path: str,
        quant_config: Optional[Dict[str, Any]] = None
    ):
        """
        Convert a quantized model to GGUF format.

        Args:
            model: Model to convert
            output_path: Path to save GGUF file
            quant_config: Optional quantization configuration
        """
        try:
            logger.log_info(f"Converting model to GGUF format: {output_path}")
            memory_tracker.log_memory("conversion_start")

            # Get quantization config from model if not provided
            if not quant_config and hasattr(model.config, 'quantization_config'):
                quant_config = model.config.quantization_config

            # Create quantizer with existing or default config
            quantizer = GGUFQuantizer(
                model_name=model,
                bits=quant_config.get('bits', 4) if quant_config else 4,
                group_size=quant_config.get('group_size', 128) if quant_config else 128,
                quant_type=quant_config.get('quant_type', None) if quant_config else None,
                use_packed=quant_config.get('use_packed', True) if quant_config else True
            )

            # Convert to GGUF
            quantizer.convert_to_gguf(output_path)
            memory_tracker.log_memory("conversion_complete")
            logger.log_info("GGUF conversion completed successfully")

        except Exception as e:
            logger.log_error(f"GGUF conversion failed: {str(e)}")
            raise
        finally:
            memory_tracker.clear_memory()
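
For orientation, a minimal usage sketch of the API this file adds; the model id and output paths below are placeholders, not part of the commit:

    from quantllm.api.quantization import QuantizationAPI

    # Quantize to 4-bit GGUF and collect benchmark numbers (model id is a placeholder).
    model, results = QuantizationAPI.quantize_model(
        "facebook/opt-125m",
        bits=4,
        group_size=128,
        benchmark=True,
    )

    # Persist the quantized weights (plus tokenizer) and emit a .gguf file.
    QuantizationAPI.save_quantized_model(model, "./opt-125m-4bit")
    QuantizationAPI.convert_to_gguf(model, "./opt-125m-4bit/model.gguf")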

quantllm/cli/commands.py

Lines changed: 3 additions & 5 deletions
@@ -4,11 +4,10 @@
 from ..training import FineTuningTrainer, ModelEvaluator
 from ..config import ModelConfig, TrainingConfig, DatasetConfig
 from ..runtime import DeviceManager
-from ..utils.monitoring import TrainingLogger
+from ..utils.logger import logger

 def train(args: Namespace):
     """Execute model training command."""
-    logger = TrainingLogger()
     device_manager = DeviceManager()

     try:
@@ -44,12 +43,11 @@ def train(args: Namespace):
         trainer.train()

     except Exception as e:
-        logger.error(f"Training failed: {str(e)}")
+        logger.log_error(f"Training failed: {str(e)}")
         raise

 def evaluate(args: Namespace):
     """Execute model evaluation command."""
-    logger = TrainingLogger()
     device_manager = DeviceManager()

     try:
@@ -78,7 +76,7 @@ def evaluate(args: Namespace):
         evaluator.save_results(results, args.output_file)

     except Exception as e:
-        logger.error(f"Evaluation failed: {str(e)}")
+        logger.log_error(f"Evaluation failed: {str(e)}")
         raise

 def quantize(args: Namespace):
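
This hunk swaps the per-command TrainingLogger instances for one module-level logger imported from ..utils.logger. That module is not shown in this commit, so the following is only a sketch of what such a shared logger might look like, assuming it wraps the standard logging package and exposes the log_info/log_warning/log_error interface the diff relies on:

    import logging

    class _QuantLLMLogger:
        # Hypothetical wrapper; only its method names are confirmed by the diff.
        def __init__(self, name: str = "quantllm"):
            logging.basicConfig(level=logging.INFO)
            self._log = logging.getLogger(name)

        def log_info(self, msg: str) -> None:
            self._log.info(msg)

        def log_warning(self, msg: str) -> None:
            self._log.warning(msg)

        def log_error(self, msg: str) -> None:
            self._log.error(msg)

    # Single shared instance, imported everywhere as `from ..utils.logger import logger`.
    logger = _QuantLLMLogger()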

quantllm/config/config_manager.py

Lines changed: 15 additions & 21 deletions
@@ -2,17 +2,11 @@
 import json
 from typing import Dict, Any, Optional
 from pathlib import Path
-from ..trainer.logger import TrainingLogger
+from ..utils.logger import logger

 class ConfigManager:
-    def __init__(self, logger: Optional[TrainingLogger] = None):
-        """
-        Initialize the configuration manager.
-
-        Args:
-            logger (TrainingLogger, optional): Logger instance
-        """
-        self.logger = logger or TrainingLogger()
+    def __init__(self):
+        """Initialize the configuration manager."""
         self.configs = {}

     def load_config(self, config_path: str) -> Dict[str, Any]:
@@ -30,7 +24,7 @@ def load_config(self, config_path: str) -> Dict[str, Any]:
             if not config_path.exists():
                 raise FileNotFoundError(f"Config file not found: {config_path}")

-            self.logger.log_info(f"Loading configuration from {config_path}")
+            logger.log_info(f"Loading configuration from {config_path}")

             if config_path.suffix in ['.yaml', '.yml']:
                 with open(config_path, 'r') as f:
@@ -42,11 +36,11 @@ def load_config(self, config_path: str) -> Dict[str, Any]:
                 raise ValueError(f"Unsupported config file format: {config_path.suffix}")

             self.configs[config_path.stem] = config
-            self.logger.log_info(f"Successfully loaded configuration: {config_path.stem}")
+            logger.log_info(f"Successfully loaded configuration: {config_path.stem}")
             return config

         except Exception as e:
-            self.logger.log_error(f"Error loading configuration: {str(e)}")
+            logger.log_error(f"Error loading configuration: {str(e)}")
             raise

     def save_config(self, config: Dict[str, Any], config_path: str):
@@ -59,7 +53,7 @@ def save_config(self, config: Dict[str, Any], config_path: str):
         """
         try:
             config_path = Path(config_path)
-            self.logger.log_info(f"Saving configuration to {config_path}")
+            logger.log_info(f"Saving configuration to {config_path}")

             if config_path.suffix in ['.yaml', '.yml']:
                 with open(config_path, 'w') as f:
@@ -71,10 +65,10 @@ def save_config(self, config: Dict[str, Any], config_path: str):
                 raise ValueError(f"Unsupported config file format: {config_path.suffix}")

             self.configs[config_path.stem] = config
-            self.logger.log_info(f"Successfully saved configuration: {config_path.stem}")
+            logger.log_info(f"Successfully saved configuration: {config_path.stem}")

         except Exception as e:
-            self.logger.log_error(f"Error saving configuration: {str(e)}")
+            logger.log_error(f"Error saving configuration: {str(e)}")
             raise

     def get_config(self, config_name: str) -> Dict[str, Any]:
@@ -103,12 +97,12 @@ def update_config(self, config_name: str, updates: Dict[str, Any]):
             raise KeyError(f"Configuration not found: {config_name}")

         try:
-            self.logger.log_info(f"Updating configuration: {config_name}")
+            logger.log_info(f"Updating configuration: {config_name}")
             self.configs[config_name].update(updates)
-            self.logger.log_info(f"Successfully updated configuration: {config_name}")
+            logger.log_info(f"Successfully updated configuration: {config_name}")

         except Exception as e:
-            self.logger.log_error(f"Error updating configuration: {str(e)}")
+            logger.log_error(f"Error updating configuration: {str(e)}")
             raise

     def validate_config(self, config_name: str, schema: Dict[str, Any]) -> bool:
@@ -126,7 +120,7 @@ def validate_config(self, config_name: str, schema: Dict[str, Any]) -> bool:
             raise KeyError(f"Configuration not found: {config_name}")

         try:
-            self.logger.log_info(f"Validating configuration: {config_name}")
+            logger.log_info(f"Validating configuration: {config_name}")
             config = self.configs[config_name]

             # Basic schema validation
@@ -136,9 +130,9 @@ def validate_config(self, config_name: str, schema: Dict[str, Any]) -> bool:
                 if not isinstance(config[key], value_type):
                     raise TypeError(f"Invalid type for {key}: expected {value_type}, got {type(config[key])}")

-            self.logger.log_info(f"Configuration validation successful: {config_name}")
+            logger.log_info(f"Configuration validation successful: {config_name}")
             return True

         except Exception as e:
-            self.logger.log_error(f"Configuration validation failed: {str(e)}")
+            logger.log_error(f"Configuration validation failed: {str(e)}")
             raise
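
With every method now logging through the shared logger, ConfigManager takes no constructor arguments. A short usage sketch of the refactored class (file names and schema below are placeholders):

    from quantllm.config.config_manager import ConfigManager

    manager = ConfigManager()                      # no logger argument anymore
    config = manager.load_config("train.yaml")     # YAML or JSON, chosen by file extension

    # Shallow validation: each schema key must exist in the config with the given type.
    manager.validate_config("train", {"learning_rate": float, "epochs": int})

    manager.update_config("train", {"epochs": 10})
    manager.save_config(manager.get_config("train"), "train_updated.yaml")

Note that configs are stored under the file stem, so "train.yaml" is addressed as "train" in the calls above.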
