@@ -137,6 +137,14 @@ def fasterprune(
             if sparsity >= SPARSITY_THRESHOLD
             else None
         )
+
+        # map each weight column to its quantization group; under act-order,
+        # column j sits at permuted position invperm[j], so it is grouped there
+        if actorder:
+            g_idx = [invperm[i] // quant_scheme.weights.group_size for i in range(self.columns)]
+        else:
+            g_idx = [i // quant_scheme.weights.group_size for i in range(self.columns)]
+        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=W.device)

         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
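For reference, a standalone sketch of what the new `g_idx` lines compute; the sizes below are made-up toy values, not taken from the diff:

```python
import torch

columns, group_size = 8, 4          # toy sizes for illustration only
perm = torch.randperm(columns)      # stand-in for the act-order permutation
invperm = torch.argsort(perm)       # its inverse, as used in the diff

# group of original column j = its permuted position // group_size
g_idx = torch.tensor(
    [invperm[i] // group_size for i in range(columns)], dtype=torch.int32
)

# the first group_size columns in permuted order share group 0, and so on
assert all(g_idx[perm[i]] == i // group_size for i in range(columns))
```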
@@ -148,6 +156,15 @@ def fasterprune(
             Err1 = torch.zeros_like(W1)
             Losses1 = torch.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]
+
+            # TODO: handle non channel-wise strategies, along these lines:
+            #
+            #     strategy = quant_scheme.weights.strategy
+            #     if strategy is not QuantizationStrategy.CHANNEL:
+            #         idx = i
+            #
+            # (only the channel-wise path is exercised below)
+

             if sparsity >= SPARSITY_THRESHOLD:
                 tmp = (
@@ -176,6 +193,7 @@ def fasterprune(
                     else:
                         q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
                     q = torch.dequantize(q)
+
                 elif hasattr(self.layer, "quantization_scheme"):
                     quant_scheme = self.layer.quantization_scheme
                     if quant_scheme.weights is not None:
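For context on the pair of calls above: `quantize_per_channel` followed by `dequantize` is a fake-quant round trip, so the weights stay in floating point but carry per-channel rounding error. A minimal self-contained example with invented shapes and quantization parameters:

```python
import torch

w = torch.randn(4, 16)                          # toy weight block
scale = w.abs().amax(dim=1) / 127.0             # per-output-channel scales
zero_point = torch.zeros(4, dtype=torch.int64)  # symmetric quantization

q = torch.quantize_per_channel(w, scale, zero_point, 0, torch.qint8)
w_dq = torch.dequantize(q)                      # back to float

print((w - w_dq).abs().max())                   # small rounding error
```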
@@ -235,9 +253,11 @@ def fasterprune(

         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
-
+
+
         if actorder:
             W = W[:, invperm]
+            # g_idx was already mapped back to the original column order above

         if isinstance(self.layer, transformers.Conv1D):
             W = W.t()
@@ -247,6 +267,7 @@ def fasterprune(
         # place, clone() or direct assignment won't work
         self.layer.weight -= self.layer.weight
         self.layer.weight += W
+        self.g_idx = g_idx

     def free(self):
         """