
Commit afe8571

Merge pull request #10 from codewithdark-git/feat/gguf-quantization-optim
feat: Isolate and optimize GGUF quantization
2 parents 94c431c + 1cdf6e9 commit afe8571

5 files changed: +2274 additions, -569 deletions


benchmark/run_benchmarks.py

Lines changed: 230 additions & 0 deletions
import torch
import gc
import argparse
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from typing import List, Dict, Optional, Any

# Adjust import paths based on your project structure.
# Assumes quantllm is on the Python path or installed.
from quantllm.quant.gguf import GGUFQuantizer
from quantllm.utils.benchmark import QuantizationBenchmark

DEFAULT_MODEL_LIST = ["facebook/opt-125m", "facebook/opt-350m"]
DEFAULT_GGUF_CONFIGS = [
    {"name": "GGUF_B4_GS32_Packed", "bits": 4, "group_size": 32, "use_packed": True, "desc_act": False},
    {"name": "GGUF_B8_GS128_Packed", "bits": 8, "group_size": 128, "use_packed": True, "desc_act": False},
    {"name": "GGUF_B4_PerTensor_Packed", "bits": 4, "group_size": -1, "use_packed": True, "desc_act": False},
    {"name": "GGUF_B4_GS32_Packed_CPUOffload", "bits": 4, "group_size": 32, "use_packed": True, "desc_act": False, "cpu_offload": True},
]

def _get_dummy_calibration_data(batch_size=1, seq_len=128, vocab_size=50257, num_samples=32) -> torch.Tensor:
    """Generates a random tensor of calibration data on CPU."""
    return torch.randint(0, vocab_size, (num_samples, seq_len), device='cpu')

def _load_model_and_tokenizer(model_name: str, trust_remote_code: bool = True) -> tuple[Optional[AutoModelForCausalLM], Optional[AutoTokenizer]]:
    """Loads a Hugging Face model and tokenizer on CPU."""
    try:
        print(f"Loading model: {model_name}...")
        # Load the model on CPU to manage memory before explicit placement by the benchmark utility
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=trust_remote_code).cpu()
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print(f"Successfully loaded {model_name}.")
        return model.eval(), tokenizer  # Ensure the model is in eval mode
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None, None

def run_gguf_benchmarks(
    model_names: Optional[List[str]] = None,
    gguf_configs: Optional[List[Dict[str, Any]]] = None,
    device: Optional[str] = None,
    num_inference_steps: int = 50,
    num_warmup_steps: int = 10,
    seq_len_for_calib: int = 128,
    num_calib_samples: int = 32
):
    """
    Runs GGUF quantization benchmarks using QuantizationBenchmark.

    Args:
        model_names (Optional[List[str]]): List of Hugging Face model names.
            Defaults to DEFAULT_MODEL_LIST.
        gguf_configs (Optional[List[Dict[str, Any]]]): List of GGUF configurations to test.
            Each dict must include a "name" key for reporting.
            Defaults to DEFAULT_GGUF_CONFIGS.
        device (Optional[str]): Device to run benchmarks on ('cuda', 'cpu'). Auto-detects if None.
        num_inference_steps (int): Number of timed inference steps.
        num_warmup_steps (int): Number of warm-up inference steps.
        seq_len_for_calib (int): Sequence length for dummy calibration data.
        num_calib_samples (int): Number of samples for dummy calibration data.
    """
    model_names = model_names if model_names else DEFAULT_MODEL_LIST
    gguf_configs = gguf_configs if gguf_configs else DEFAULT_GGUF_CONFIGS

    all_results_summary = []  # Stores the results DataFrame from each benchmark run

    for model_name in model_names:
        print(f"\n{'='*20} Starting GGUF Benchmarks for Model: {model_name} {'='*20}")

        original_model, tokenizer = _load_model_and_tokenizer(model_name)
        if original_model is None or tokenizer is None:
            print(f"Skipping benchmarks for {model_name} due to loading error.")
            continue

        calibration_data = _get_dummy_calibration_data(
            vocab_size=original_model.config.vocab_size,
            seq_len=seq_len_for_calib,
            num_samples=num_calib_samples
        )

        # Initialize the benchmark utility for the current model.
        # The model passed to QuantizationBenchmark is the original, unquantized model;
        # QuantizationBenchmark's _copy_model method is used internally for each quantizer run.
        benchmark_utility = QuantizationBenchmark(
            model=original_model,                 # Original model, kept on CPU by benchmark init
            calibration_data=calibration_data,    # Kept on CPU by benchmark init
            input_shape=(1, seq_len_for_calib),   # (batch_size, seq_len) for inference tests
            num_inference_steps=num_inference_steps,
            # num_warmup_steps is no longer an __init__ arg for QuantizationBenchmark
            device=device  # Benchmark utility handles device placement
        )

        # Calculate the original model's parameter size in GB for efficiency metrics.
        # Ensure the model is on CPU for this calculation.
        temp_model_cpu = original_model.to('cpu')
        original_model_size_gb = sum(
            p.numel() * p.element_size() for p in temp_model_cpu.parameters()
        ) / (1024**3)
        del temp_model_cpu
        gc.collect()

        print(f"Original model '{model_name}' parameter size: {original_model_size_gb:.3f} GB")

        for gguf_config_params in gguf_configs:
            config_name = gguf_config_params.get("name", f"GGUF_Custom_{gguf_config_params.get('bits','N')}b_GS{gguf_config_params.get('group_size','N')}")
            full_benchmark_name = f"{model_name}_{config_name}"

            print(f"\n--- Benchmarking GGUF Configuration: {config_name} for {model_name} ---")

            # Remove 'name' from the args passed to the quantizer; it is used for reporting only.
            quantizer_actual_args = {k: v for k, v in gguf_config_params.items() if k != "name"}

            try:
                # benchmark_quantizer handles copying the model, quantizing, and running inference tests
                benchmark_utility.benchmark_quantizer(
                    name=full_benchmark_name,  # This name becomes a key in benchmark_utility.results
                    quantizer_class=GGUFQuantizer,
                    quantizer_args=quantizer_actual_args,
                    original_model_size_gb=original_model_size_gb,
                    num_warmup_steps=num_warmup_steps  # Pass num_warmup_steps here
                )
            except Exception as e:
                print(f"Error during benchmark for {full_benchmark_name}: {e}")
                # Store the error in results if benchmark_quantizer did not handle it internally
                if full_benchmark_name not in benchmark_utility.results:
                    benchmark_utility.results[full_benchmark_name] = {"error": str(e)}
                else:  # If benchmark_quantizer stored partial results, add/update the error
                    benchmark_utility.results[full_benchmark_name]["error"] = str(e)

        # Print the report for the current model after all of its GGUF configs have been benchmarked.
        print(f"\n--- Benchmark Report for Model: {model_name} ---")
        # print_report() uses benchmark_utility.results, which now contains
        # all runs for this specific model instance.
        benchmark_utility.print_report()

        # Store the results DataFrame for this model.
        # run_all_benchmarks (inside print_report) returns a DataFrame, but here we
        # build one directly from the current benchmark_utility instance's results.
        current_model_df = pd.DataFrame.from_dict(benchmark_utility.results, orient='index')
        current_model_df['model_name'] = model_name  # Add the model name for the combined report
        all_results_summary.append(current_model_df)

        # Clean up for the current model
        del original_model, tokenizer, calibration_data, benchmark_utility
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print(f"\n{'='*20} Finished GGUF Benchmarks for Model: {model_name} {'='*20}")

    if all_results_summary:
        final_summary_df = pd.concat(all_results_summary)
        print("\n\n===== Overall GGUF Benchmark Summary =====")
        # Columns could be re-formatted or filtered for the final summary;
        # for now, just print the concatenated DataFrame.
        pd.set_option('display.max_rows', None)
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', 1000)
        print(final_summary_df)
    else:
        print("No benchmark results were collected.")


def main():
    parser = argparse.ArgumentParser(description="Run GGUF Quantization Benchmarks.")
    parser.add_argument(
        "--model_names",
        type=str,
        nargs="+",
        default=DEFAULT_MODEL_LIST,
        help="List of Hugging Face model names to benchmark."
    )
    # Configs are defined in code for now; they could be loaded from JSON/YAML in the future.
    parser.add_argument(
        "--device",
        type=str,
        default=None,  # Auto-detect
        help="Device to run benchmarks on (e.g., 'cuda', 'cuda:0', 'cpu')."
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=50,
        help="Number of inference steps for latency/throughput measurement."
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=10,
        help="Number of warm-up steps before timed inference."
    )
    parser.add_argument(
        "--seq_len_calib",
        type=int,
        default=128,
        help="Sequence length for dummy calibration data."
    )
    parser.add_argument(
        "--num_calib_samples",
        type=int,
        default=32,
        help="Number of samples for dummy calibration data."
    )

    args = parser.parse_args()

    print("Starting GGUF Benchmark Suite...")
    print(f"Models to benchmark: {args.model_names}")
    print(f"GGUF Configurations defined in code: {[c['name'] for c in DEFAULT_GGUF_CONFIGS]}")
    print(f"Device: {'Auto-detect' if args.device is None else args.device}")
    print(f"Inference Steps: {args.num_inference_steps}, Warm-up Steps: {args.num_warmup_steps}")

    run_gguf_benchmarks(
        model_names=args.model_names,
        gguf_configs=DEFAULT_GGUF_CONFIGS,  # Using the hardcoded default configs
        device=args.device,
        num_inference_steps=args.num_inference_steps,
        num_warmup_steps=args.num_warmup_steps,
        seq_len_for_calib=args.seq_len_calib,
        num_calib_samples=args.num_calib_samples
    )

    print("\nGGUF Benchmark Suite Finished.")

if __name__ == "__main__":
    main()
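
Usage note (not part of the diff): since the new script exposes run_gguf_benchmarks() as a plain function, it can be driven programmatically as well as through the CLI in main(). The sketch below is a minimal, hypothetical example; the import path benchmark.run_benchmarks and the group_size=64 configuration are assumptions for illustration, not something this PR ships.

# Hypothetical programmatic use of the benchmark runner added in this PR.
# Assumes the repository root is on PYTHONPATH so that benchmark/run_benchmarks.py
# is importable as benchmark.run_benchmarks, and that quantllm is installed.
from benchmark.run_benchmarks import run_gguf_benchmarks

custom_configs = [
    # Keys mirror DEFAULT_GGUF_CONFIGS; "name" is used only for reporting.
    {"name": "GGUF_B4_GS64_Packed", "bits": 4, "group_size": 64, "use_packed": True, "desc_act": False},
]

run_gguf_benchmarks(
    model_names=["facebook/opt-125m"],
    gguf_configs=custom_configs,
    device=None,               # auto-detect CUDA vs. CPU
    num_inference_steps=20,
    num_warmup_steps=5,
    seq_len_for_calib=128,
    num_calib_samples=16,
)

The same knobs are exposed on the command line via the argparse flags defined in main() (--model_names, --device, --num_inference_steps, --num_warmup_steps, --seq_len_calib, --num_calib_samples); the GGUF configurations themselves are currently hardcoded in DEFAULT_GGUF_CONFIGS.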
