Skip to content

Commit 828e185

Browse files
committed
add g_idx
1 parent d39e7f9 commit 828e185

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py

+22 -1
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,14 @@ def fasterprune(
137137
if sparsity >= SPARSITY_THRESHOLD
138138
else None
139139
)
140+
141+
g_idx = []
142+
if actorder:
143+
g_idx = [perm[i] // quant_scheme.weights.group_size for i in range(self.columns)]
144+
g_idx = g_idx[invperm]
145+
else:
146+
g_idx = [i // quant_scheme.weights.group_size for i in range(self.columns)]
147+
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=W.device)
140148

141149
# See section 3.4 of https://arxiv.org/abs/2203.07259
142150
for i1 in range(0, self.columns, blocksize):
@@ -148,6 +156,15 @@ def fasterprune(
148156
Err1 = torch.zeros_like(W1)
149157
Losses1 = torch.zeros_like(W1)
150158
Hinv1 = Hinv[i1:i2, i1:i2]
159+
160+
# """
161+
# if not channel wise
162+
163+
# strategy = quant_scheme.weights.strategy
164+
# if strategy is not QuantizationStrategy.CHANNEL:
165+
# idx = i
166+
167+
# """
151168

152169
if sparsity >= SPARSITY_THRESHOLD:
153170
tmp = (
@@ -176,6 +193,7 @@ def fasterprune(
176193
else:
177194
q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
178195
q = torch.dequantize(q)
196+
179197
elif hasattr(self.layer, "quantization_scheme"):
180198
quant_scheme = self.layer.quantization_scheme
181199
if quant_scheme.weights is not None:
@@ -235,9 +253,11 @@ def fasterprune(
235253

236254
_LOGGER.info("time %.2f" % (time.time() - tick))
237255
_LOGGER.info("error %.2f" % torch.sum(Losses).item())
238-
256+
257+
239258
if actorder:
240259
W = W[:, invperm]
260+
# g_idx = g_idx[invperm]
241261

242262
if isinstance(self.layer, transformers.Conv1D):
243263
W = W.t()
@@ -247,6 +267,7 @@ def fasterprune(
247267
# place, clone() or direct assignment won't work
248268
self.layer.weight -= self.layer.weight
249269
self.layer.weight += W
270+
self.g_idx = g_idx
250271

251272
def free(self):
252273
"""

0 commit comments

Comments
 (0)