
Commit 94c431c

Merge pull request #9 from codewithdark-git/fix/awq-quantization-issues
Hi there! I've made some improvements to the AWQ quantization impleme…
2 parents f4e5b68 + 23f331e commit 94c431c

File tree

3 files changed: +229 -11 lines changed

quantllm/quant/awq.py

Lines changed: 69 additions & 10 deletions
@@ -90,7 +90,30 @@ def quantize(
                 self.logger.log_info(f"Processing layer: {name}")
 
                 # Get activation scale for this layer
-                act_scale = self.act_scales.get(name)
+                act_scale_list_or_tensor = self.act_scales.get(name)
+
+                if act_scale_list_or_tensor is not None:
+                    if isinstance(act_scale_list_or_tensor, list):
+                        if all(isinstance(t, torch.Tensor) for t in act_scale_list_or_tensor):
+                            # Average the list of tensors
+                            act_scale = torch.stack(act_scale_list_or_tensor).mean(dim=0)
+                        else:
+                            # Handle unexpected content in the list
+                            self.logger.log_error(f"Activation scales for {name} contain non-tensor elements. Quantization may be incorrect.")
+                            # Fallback: attempt to use the list directly if _quantize_layer can handle it, or create a default
+                            # For safety, creating a default scale here.
+                            act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+                    elif isinstance(act_scale_list_or_tensor, torch.Tensor):
+                        # If it's already a tensor (e.g., if averaging was done elsewhere or only one batch)
+                        act_scale = act_scale_list_or_tensor
+                    else:
+                        self.logger.log_error(f"Unexpected type for activation scales of {name}: {type(act_scale_list_or_tensor)}. Using default.")
+                        act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+                else:
+                    self.logger.log_warning(f"No activation scales found for {name}. Using default scale of 1.0.")
+                    # module.in_features should correspond to the expected dimension of the scale
+                    act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+
                 quantized = self._quantize_layer(module, act_scale)
 
                 # Replace layer in model
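For context, a minimal standalone sketch of the averaging step introduced above: per-batch activation-scale tensors of shape (in_features,) are stacked and reduced to one per-channel scale. The helper name reduce_act_scales and the channel count are illustrative, not from the repository.

import torch

def reduce_act_scales(per_batch_scales):
    # Stack to (n_batches, in_features) and average over batches,
    # mirroring torch.stack(...).mean(dim=0) in quantize() above.
    return torch.stack(per_batch_scales).mean(dim=0)

# Example: three calibration batches for a layer with 16 input channels
batches = [torch.rand(16) for _ in range(3)]
act_scale = reduce_act_scales(batches)
print(act_scale.shape)  # torch.Size([16])
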
@@ -135,10 +158,13 @@ def fn(module, input, output):
             # Handle both 2D and 3D inputs
             if len(x.shape) == 3:
                 # For 3D input (batch_size, seq_len, hidden_size)
-                scale = torch.max(torch.abs(x.view(-1, x.size(-1))))
+                # Compute scales per hidden channel: (hidden_size,)
+                scale = torch.amax(torch.abs(x), dim=[0, 1])
             else:
-                scale = torch.max(torch.abs(x))
-            # Store scale in our temporary dictionary
+                # For 2D input (batch_size, hidden_size)
+                # Compute scales per hidden channel: (hidden_size,)
+                scale = torch.amax(torch.abs(x), dim=0)
+            # Store scale tensor (moved to CPU) in our temporary dictionary
             batch_scales[name].append(scale.cpu())
         return fn
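As a rough illustration of the change above (tensor shapes are assumed for the example, not taken from the repo): torch.amax over the batch and sequence dimensions yields one scale per hidden channel, whereas the old torch.max call collapsed everything to a single scalar.

import torch

x3d = torch.randn(4, 10, 16)   # (batch_size, seq_len, hidden_size)
x2d = torch.randn(4, 16)       # (batch_size, hidden_size)

old_scale = torch.max(torch.abs(x3d.view(-1, x3d.size(-1))))  # 0-dim scalar tensor
new_scale_3d = torch.amax(torch.abs(x3d), dim=[0, 1])         # shape (16,)
new_scale_2d = torch.amax(torch.abs(x2d), dim=0)              # shape (16,)

print(old_scale.shape, new_scale_3d.shape, new_scale_2d.shape)
# torch.Size([]) torch.Size([16]) torch.Size([16])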

@@ -150,6 +176,7 @@ def fn(module, input, output):
         with torch.no_grad():
             data_on_device = move_to_device(data, self.device_manager.primary_device)
             self.model(data_on_device)
+            del data_on_device # Free memory after forward pass
 
         # Remove hooks
         for handle in handles:
@@ -158,12 +185,12 @@ def fn(module, input, output):
         # Process the collected scales
         for name in batch_scales:
             if batch_scales[name]: # If we collected any scales for this layer
-                scales_tensor = torch.stack(batch_scales[name])
-                # If this is the first batch
+                # If this is the first batch for this layer
                 if name not in self.act_scales:
                     self.act_scales[name] = []
-                # Add the processed scales to our main storage
-                self.act_scales[name].extend([s.item() for s in scales_tensor])
+                # Extend the list of scale tensors for this layer
+                # batch_scales[name] already contains CPU tensors
+                self.act_scales[name].extend(batch_scales[name])
 
         # Clean up
         del batch_scales
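A short note on why the .item() call had to go: with per-channel statistics each collected entry is a 1-D tensor, and Tensor.item() only accepts single-element tensors, so keeping the tensors intact is required. A minimal sketch of the failure mode (the 16-channel shape is assumed for illustration):

import torch

per_channel_scale = torch.rand(16)  # one scale per hidden channel

# Keeping the tensor preserves per-channel information; it can later be
# averaged with torch.stack([...]).mean(dim=0) as in quantize().
stored = [per_channel_scale]

# The previous flattening approach fails for multi-element tensors:
try:
    _ = per_channel_scale.item()
except RuntimeError as e:
    print(f"item() is only valid for single-element tensors: {e}")
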
@@ -206,13 +233,43 @@ def _quantize_layer(
 
         # Ensure act_scale is on the same device as W before division
         act_scale_on_device = move_to_device(act_scale, W.device)
-        W = W / act_scale_on_device.view(1, -1)
+
+        try:
+            W = W / act_scale_on_device.view(1, -1)
+        except RuntimeError as e:
+            error_message = (
+                f"Failed to scale weights with activation scales in _quantize_layer.\n"
+                f"  Weight (W) shape: {W.shape}\n"
+                f"  Activation scale (act_scale_on_device) shape: {act_scale_on_device.shape}\n"
+                f"  Original error: {str(e)}"
+            )
+            self.logger.log_error(error_message)
+            raise RuntimeError(error_message) from e
 
         # Compute quantization scales per group
         # All computations for scales and zero_points should happen on target_device
         if self.group_size > 0:
+            if W.shape[0] % self.group_size != 0:
+                error_message = (
+                    f"Weight dimension {W.shape[0]} is not divisible by group_size {self.group_size} "
+                    f"in _quantize_layer for layer being processed."
+                )
+                self.logger.log_error(error_message)
+                raise ValueError(error_message) # ValueError is more appropriate here
+
             n_groups = W.shape[0] // self.group_size
-            W_groups = W.view(n_groups, self.group_size, -1)
+            try:
+                W_groups = W.view(n_groups, self.group_size, -1)
+            except RuntimeError as e:
+                error_message = (
+                    f"Failed to create view for grouped weights in _quantize_layer.\n"
+                    f"  Weight (W) shape: {W.shape}\n"
+                    f"  Calculated n_groups: {n_groups}\n"
+                    f"  Group size: {self.group_size}\n"
+                    f"  Original error: {str(e)}"
+                )
+                self.logger.log_error(error_message)
+                raise RuntimeError(error_message) from e
 
         scales_list = [] # Renamed from scales to scales_list
         zero_points_list = [] if self.zero_point else None # Renamed
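To see what the new guard protects, here is a small standalone sketch of splitting a weight matrix into quantization groups along its first dimension; the shapes and variable names are illustrative, not the library's actual values.

import torch

W = torch.randn(32, 16)   # e.g. (out_features, in_features) of a Linear layer
group_size = 16

# Same guard as in _quantize_layer: view() requires an exact split.
if W.shape[0] % group_size != 0:
    raise ValueError(f"{W.shape[0]} is not divisible by group_size {group_size}")

n_groups = W.shape[0] // group_size
W_groups = W.view(n_groups, group_size, -1)
print(W_groups.shape)  # torch.Size([2, 16, 16])
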
@@ -246,12 +303,14 @@ def _quantize_layer(
         # W, scales, zero_points are on target_device
         W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1))
         W_quant = W_quant.to(torch.int8) # Cast to int8
+        del W # Free memory for W as it's no longer needed
 
         # Store quantized weights and parameters
         # quantized module and its buffers are already on target_device
         quantized.weight_quantized.copy_(W_quant) # W_quant is already on target_device and int8
         quantized.weight_scale.copy_(1.0 / scales) # scales is on target_device
         quantized.weight_zero_point.copy_(zero_points) # zero_points is on target_device
+        del scales, zero_points # Free memory for scales and zero_points
 
         # Store additional AWQ-specific information
         # Ensure act_scale is on the same device as the quantized layer's parameters
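For intuition about the final step, a sketch of the quantize/dequantize round trip under an assumed asymmetric per-row convention (not the library's exact code path): weights are mapped to int8 via round(W * s - z) and approximately recovered as (W_q + z) / s, which is why the reciprocal 1.0 / scales is stored as weight_scale.

import torch

torch.manual_seed(0)
W = torch.randn(4, 8)            # one quantization group, for illustration
bits = 4
qmax = 2 ** bits - 1             # 15 for 4-bit

# Per-row asymmetric scale/zero-point (assumed convention for this sketch)
w_min, w_max = W.min(dim=1).values, W.max(dim=1).values
scales = qmax / (w_max - w_min)  # maps each row's range onto [0, qmax]
zero_points = w_min * scales     # so that w_min quantizes to 0

W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1)).to(torch.int8)
W_dequant = (W_quant.float() + zero_points.view(-1, 1)) * (1.0 / scales).view(-1, 1)

print(torch.max(torch.abs(W - W_dequant)))  # small reconstruction error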

quantllm/quant/quantization_engine.py

Lines changed: 7 additions & 1 deletion
@@ -572,9 +572,15 @@ def prepare_calibration_data(self, calibration_data: torch.Tensor) -> torch.Tensor:
 
         return calibration_data
 
+import gc # Moved import to top of file
+
+# ... (other imports and code) ...
+
+class BaseQuantizer:
+    # ... (other methods) ...
+
     def _clear_memory(self):
         """Clear GPU memory and run garbage collection."""
-        import gc
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
import unittest
import torch
import torch.nn as nn
from typing import List

# Assuming AWQQuantizer and QuantizedLinear are accessible from this path
# Adjust the import path based on your project structure if necessary
from quantllm.quant.awq import AWQQuantizer
from quantllm.quant.quantization_engine import QuantizedLinear, QuantizationConfig

# 1. Dummy Model Definition
class DummyModel(nn.Module):
    def __init__(self, in_features, out_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        # If input is 3D (batch, seq, features), flatten sequence for linear layers
        original_shape = x.shape
        if x.ndim == 3:
            x = x.view(-1, original_shape[-1])

        x = self.relu(self.fc1(x))
        x = self.fc2(x)

        # Reshape back if original input was 3D
        if len(original_shape) == 3:
            x = x.view(original_shape[0], original_shape[1], -1)
        return x

class TestAWQQuantizer(unittest.TestCase):
    def setUp(self):
        self.in_features = 16
        self.hidden_features = 32  # Must be divisible by group_size if group_size is not -1 or 1
        self.out_features = 8
        self.seq_len = 10
        self.batch_size = 4

        # Instantiate the dummy model
        self.model = DummyModel(self.in_features, self.out_features, self.hidden_features)
        self.model.eval()  # Important for quantization

        # Create dummy calibration data
        # Shape: (batch_size, seq_len, in_features) - typical for NLP tasks
        self.dummy_calibration_data = torch.randn(self.batch_size, self.seq_len, self.in_features)

        # Device configuration
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.dummy_calibration_data = self.dummy_calibration_data.to(self.device)

    def test_awq_scale_computation_and_application(self):
        # Instantiate AWQQuantizer
        # group_size chosen to be compatible with hidden_features for fc2's input
        # and in_features for fc1's input if group_size were applied to fc1's weights
        # For this test, group_size = -1 would also work if not testing grouping specifically for fc1.
        # Let's use a group_size that divides in_features for fc1 and hidden_features for fc2
        group_size = 16  # Divides self.in_features (16) and self.hidden_features (32)

        quantizer = AWQQuantizer(
            model_name=self.model,  # Pass the model instance
            bits=4,
            group_size=group_size,
            zero_point=True,
            device=self.device
        )

        # --- Collect Activation Stats ---
        # Process a single batch for simplicity in checking act_scales structure
        # _collect_activation_stats expects a single batch
        quantizer._collect_activation_stats(self.dummy_calibration_data[0].unsqueeze(0))  # Pass one sample from batch

        # Assert: Check quantizer.act_scales
        self.assertIn('fc1', quantizer.act_scales)
        self.assertIn('fc2', quantizer.act_scales)

        # For fc1
        fc1_scales_list = quantizer.act_scales['fc1']
        self.assertIsInstance(fc1_scales_list, list)
        self.assertTrue(len(fc1_scales_list) > 0, "fc1_scales_list should not be empty")
        for scale_tensor in fc1_scales_list:
            self.assertIsInstance(scale_tensor, torch.Tensor)
            self.assertEqual(scale_tensor.ndim, 1)
            self.assertEqual(scale_tensor.shape[0], self.in_features)  # in_features of fc1

        # For fc2
        fc2_scales_list = quantizer.act_scales['fc2']
        self.assertIsInstance(fc2_scales_list, list)
        self.assertTrue(len(fc2_scales_list) > 0, "fc2_scales_list should not be empty")
        for scale_tensor in fc2_scales_list:
            self.assertIsInstance(scale_tensor, torch.Tensor)
            self.assertEqual(scale_tensor.ndim, 1)
            self.assertEqual(scale_tensor.shape[0], self.hidden_features)  # in_features of fc2

        # --- Quantize Layer (Focusing on Scale Application for fc1) ---
        layer_to_quantize_fc1 = self.model.fc1
        act_scale_list_fc1 = quantizer.act_scales.get('fc1')
        self.assertIsNotNone(act_scale_list_fc1)
        self.assertIsInstance(act_scale_list_fc1, list)

        # Average the collected scales
        act_scale_tensor_fc1 = torch.stack(act_scale_list_fc1).mean(dim=0)

        self.assertEqual(act_scale_tensor_fc1.ndim, 1)
        self.assertEqual(act_scale_tensor_fc1.shape[0], self.in_features)

        try:
            quantized_layer_fc1 = quantizer._quantize_layer(layer_to_quantize_fc1, act_scale_tensor_fc1)
        except RuntimeError as e:
            self.fail(f"_quantize_layer raised RuntimeError unexpectedly: {e}")

        self.assertIsInstance(quantized_layer_fc1, QuantizedLinear)

        # --- Full Quantization and Forward Pass (Integration Check) ---
        # Re-instantiate quantizer for a clean full run or clear previous act_scales
        quantizer_full = AWQQuantizer(
            model_name=self.model,  # Pass a new copy or re-initialize
            bits=4,
            group_size=group_size,
            zero_point=True,
            device=self.device
        )

        try:
            # Use the full calibration dataset and specify steps (can be number of batches)
            quantized_model = quantizer_full.quantize(
                calibration_data=self.dummy_calibration_data,
                calibration_steps=self.batch_size
            )
        except Exception as e:  # Catch any exception during full quantization
            self.fail(f"quantizer.quantize raised an exception unexpectedly: {e}")

        self.assertIsNotNone(quantized_model)
        # Check if layers are replaced
        self.assertIsInstance(quantized_model.fc1, QuantizedLinear)
        self.assertIsInstance(quantized_model.fc2, QuantizedLinear)

        # Perform a forward pass
        sample_input = self.dummy_calibration_data[0].unsqueeze(0)  # Take one sample
        try:
            output = quantized_model(sample_input.to(self.device))
        except RuntimeError as e:
            self.fail(f"Forward pass on quantized_model raised RuntimeError unexpectedly: {e}")

        # Assert output shape
        # Output shape should be (batch_size_sample, seq_len_sample, out_features)
        # For sample_input: (1, self.seq_len, self.out_features)
        self.assertEqual(output.shape, (1, self.seq_len, self.out_features))

if __name__ == '__main__':
    unittest.main()

0 commit comments
