@@ -90,7 +90,30 @@ def quantize(
        self.logger.log_info(f"Processing layer: {name}")

        # Get activation scale for this layer
-       act_scale = self.act_scales.get(name)
+       act_scale_list_or_tensor = self.act_scales.get(name)
+
+       if act_scale_list_or_tensor is not None:
+           if isinstance(act_scale_list_or_tensor, list):
+               if all(isinstance(t, torch.Tensor) for t in act_scale_list_or_tensor):
+                   # Average the list of tensors
+                   act_scale = torch.stack(act_scale_list_or_tensor).mean(dim=0)
+               else:
+                   # Handle unexpected content in the list
+                   self.logger.log_error(f"Activation scales for {name} contain non-tensor elements. Quantization may be incorrect.")
+                   # Fallback: attempt to use the list directly if _quantize_layer can handle it, or create a default
+                   # For safety, creating a default scale here.
+                   act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+           elif isinstance(act_scale_list_or_tensor, torch.Tensor):
+               # If it's already a tensor (e.g., if averaging was done elsewhere or only one batch)
+               act_scale = act_scale_list_or_tensor
+           else:
+               self.logger.log_error(f"Unexpected type for activation scales of {name}: {type(act_scale_list_or_tensor)}. Using default.")
+               act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+       else:
+           self.logger.log_warning(f"No activation scales found for {name}. Using default scale of 1.0.")
+           # module.in_features should correspond to the expected dimension of the scale
+           act_scale = torch.ones(module.in_features, device=self.device_manager.primary_device)
+
        quantized = self._quantize_layer(module, act_scale)

        # Replace layer in model
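For reference, the branch above reduces a list of per-batch, per-channel scale tensors to a single tensor by averaging. A minimal standalone sketch of that reduction, with invented shapes and values and only `torch` assumed:

import torch

# Hypothetical per-batch abs-max scales collected for one Linear layer
# with in_features = 4; each entry has shape (in_features,).
collected = [torch.tensor([0.9, 2.0, 0.5, 1.1]),
             torch.tensor([1.1, 1.8, 0.7, 0.9])]

if collected and all(isinstance(t, torch.Tensor) for t in collected):
    # Stack to (n_batches, in_features), then average over the batch axis.
    act_scale = torch.stack(collected).mean(dim=0)  # tensor([1.0, 1.9, 0.6, 1.0])
else:
    # Fallback mirroring the patch: a neutral per-channel scale of 1.0.
    act_scale = torch.ones(4)

print(act_scale.shape)  # torch.Size([4])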
@@ -135,10 +158,13 @@ def fn(module, input, output):
            # Handle both 2D and 3D inputs
            if len(x.shape) == 3:
                # For 3D input (batch_size, seq_len, hidden_size)
-               scale = torch.max(torch.abs(x.view(-1, x.size(-1))))
+               # Compute scales per hidden channel: (hidden_size,)
+               scale = torch.amax(torch.abs(x), dim=[0, 1])
            else:
-               scale = torch.max(torch.abs(x))
-           # Store scale in our temporary dictionary
+               # For 2D input (batch_size, hidden_size)
+               # Compute scales per hidden channel: (hidden_size,)
+               scale = torch.amax(torch.abs(x), dim=0)
+           # Store scale tensor (moved to CPU) in our temporary dictionary
            batch_scales[name].append(scale.cpu())
        return fn
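The switch from a global `torch.max` to per-channel `torch.amax` changes the recorded scale from a scalar to a `(hidden_size,)` tensor. A small sketch with made-up shapes showing both branches, and the old scalar behaviour for comparison:

import torch

hidden_size = 8
x3d = torch.randn(2, 5, hidden_size)  # (batch_size, seq_len, hidden_size)
x2d = torch.randn(2, hidden_size)     # (batch_size, hidden_size)

# Per-channel absolute maxima: one value per hidden unit.
scale_3d = torch.amax(torch.abs(x3d), dim=[0, 1])  # shape (hidden_size,)
scale_2d = torch.amax(torch.abs(x2d), dim=0)       # shape (hidden_size,)

# The removed code collapsed everything to a single scalar instead:
scalar_scale = torch.max(torch.abs(x3d.view(-1, x3d.size(-1))))  # shape ()

assert scale_3d.shape == (hidden_size,) and scale_2d.shape == (hidden_size,)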
@@ -150,6 +176,7 @@ def fn(module, input, output):
        with torch.no_grad():
            data_on_device = move_to_device(data, self.device_manager.primary_device)
            self.model(data_on_device)
+           del data_on_device  # Free memory after forward pass

        # Remove hooks
        for handle in handles:
@@ -158,12 +185,12 @@ def fn(module, input, output):
        # Process the collected scales
        for name in batch_scales:
            if batch_scales[name]:  # If we collected any scales for this layer
-               scales_tensor = torch.stack(batch_scales[name])
-               # If this is the first batch
+               # If this is the first batch for this layer
                if name not in self.act_scales:
                    self.act_scales[name] = []
-               # Add the processed scales to our main storage
-               self.act_scales[name].extend([s.item() for s in scales_tensor])
+               # Extend the list of scale tensors for this layer
+               # batch_scales[name] already contains CPU tensors
+               self.act_scales[name].extend(batch_scales[name])

        # Clean up
        del batch_scales
@@ -206,13 +233,43 @@ def _quantize_layer(

        # Ensure act_scale is on the same device as W before division
        act_scale_on_device = move_to_device(act_scale, W.device)
-       W = W / act_scale_on_device.view(1, -1)
+
+       try:
+           W = W / act_scale_on_device.view(1, -1)
+       except RuntimeError as e:
+           error_message = (
+               f"Failed to scale weights with activation scales in _quantize_layer.\n"
+               f"  Weight (W) shape: {W.shape}\n"
+               f"  Activation scale (act_scale_on_device) shape: {act_scale_on_device.shape}\n"
+               f"  Original error: {str(e)}"
+           )
+           self.logger.log_error(error_message)
+           raise RuntimeError(error_message) from e

        # Compute quantization scales per group
        # All computations for scales and zero_points should happen on target_device
        if self.group_size > 0:
+           if W.shape[0] % self.group_size != 0:
+               error_message = (
+                   f"Weight dimension {W.shape[0]} is not divisible by group_size {self.group_size} "
+                   f"in _quantize_layer for layer being processed."
+               )
+               self.logger.log_error(error_message)
+               raise ValueError(error_message)  # ValueError is more appropriate here
+
            n_groups = W.shape[0] // self.group_size
-           W_groups = W.view(n_groups, self.group_size, -1)
+           try:
+               W_groups = W.view(n_groups, self.group_size, -1)
+           except RuntimeError as e:
+               error_message = (
+                   f"Failed to create view for grouped weights in _quantize_layer.\n"
+                   f"  Weight (W) shape: {W.shape}\n"
+                   f"  Calculated n_groups: {n_groups}\n"
+                   f"  Group size: {self.group_size}\n"
+                   f"  Original error: {str(e)}"
+               )
+               self.logger.log_error(error_message)
+               raise RuntimeError(error_message) from e

            scales_list = []  # Renamed from scales to scales_list
            zero_points_list = [] if self.zero_point else None  # Renamed
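The new divisibility check guards the `view` that follows: reshaping to `(n_groups, group_size, -1)` only succeeds when the first weight dimension splits evenly into groups. A toy sketch of the same grouping logic with invented dimensions:

import torch

out_features, in_features, group_size = 8, 4, 4
W = torch.randn(out_features, in_features)

if W.shape[0] % group_size != 0:
    raise ValueError(
        f"Weight dimension {W.shape[0]} is not divisible by group_size {group_size}"
    )

n_groups = W.shape[0] // group_size          # 8 // 4 = 2
W_groups = W.view(n_groups, group_size, -1)  # shape (2, 4, 4)
print(W_groups.shape)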
@@ -246,12 +303,14 @@ def _quantize_layer(
        # W, scales, zero_points are on target_device
        W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1))
        W_quant = W_quant.to(torch.int8)  # Cast to int8
+       del W  # Free memory for W as it's no longer needed

        # Store quantized weights and parameters
        # quantized module and its buffers are already on target_device
        quantized.weight_quantized.copy_(W_quant)  # W_quant is already on target_device and int8
        quantized.weight_scale.copy_(1.0 / scales)  # scales is on target_device
        quantized.weight_zero_point.copy_(zero_points)  # zero_points is on target_device
+       del scales, zero_points  # Free memory for scales and zero_points

        # Store additional AWQ-specific information
        # Ensure act_scale is on the same device as the quantized layer's parameters
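For context on the rounding step above, a minimal sketch of the affine quantize/dequantize round trip implied by this code, with invented values; the dequantization step is not part of the patch and assumes `weight_scale` holds `1.0 / scales` as stored above:

import torch

# One group of weights plus invented per-group parameters.
W = torch.tensor([[-0.50, 0.25, 0.75]])
scales = torch.tensor([10.0])       # per-group quantization scale
zero_points = torch.tensor([0.0])   # asymmetric offset (0.0 keeps the example simple)

# Quantize, in the same form as the patch, then cast to int8.
W_quant = torch.round(W * scales.view(-1, 1) - zero_points.view(-1, 1)).to(torch.int8)
# -> tensor([[-5, 2, 8]], dtype=torch.int8); 2.5 rounds to 2 (round-half-to-even)

# Dequantize using the stored reciprocal scale (weight_scale = 1.0 / scales).
weight_scale = 1.0 / scales
W_dequant = (W_quant.float() + zero_points.view(-1, 1)) * weight_scale.view(-1, 1)
# -> tensor([[-0.5000, 0.2000, 0.8000]]), close to the original W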