
Commit 61b2b5d

replace is_perchannel with granularity

Squashed WIP messages: "seems ok version", "tested ok version", "disable input quantizer of silu and upsample", "rectified version", "black", "fix bug".

1 parent f473aef · commit 61b2b5d

16 files changed: +230 −92 lines

sparsebit/quantization/common.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 class Granularity(Enum):
     LAYERWISE = 0
     CHANNELWISE = 1
+    GROUPWISE = 2


 class QuantTarget(Enum):
@@ -44,6 +45,10 @@ def get_qscheme(qscheme):
         return torch.per_channel_symmetric
     if qscheme == "per-channel-affine":
         return torch.per_channel_affine
+    if qscheme == "per-group-symmetric":
+        return "per-group-symmetric"
+    if qscheme == "per-group-affine":
+        return "per-group-affine"
     raise TypeError(
         "only support a qscheme equals to per-[tensor/channel]-[affine/symmetric] , not {}".format(
             qscheme
```
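torch defines no per-group qscheme constant, so the new branches return the raw strings and downstream checks mix torch enums with strings. A minimal sketch of that mixed membership test (values are illustrative, not from the repo):

```python
# Sketch: why returning plain strings works downstream. "per-group-affine"
# is compared as a string alongside real torch qscheme enums; `in` uses ==,
# which is simply False across the mismatched types.
import torch

scheme = "per-group-affine"  # as returned by get_qscheme after this commit
is_affine = scheme in [
    torch.per_tensor_affine,
    torch.per_channel_affine,
    "per-group-affine",
]
print(is_affine)  # True
```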

sparsebit/quantization/observers/aciq.py

Lines changed: 15 additions & 12 deletions
```diff
@@ -63,16 +63,18 @@ def __init__(self, config, qdesc):
         self.gaus_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5)

     def calc_laplace_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             b = torch.mean(torch.abs(data - data.mean(1).unsqueeze(1)), dim=1)
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        elif self.granularity == Granularity.LAYERWISE:
             b = torch.mean(torch.abs(data - data.mean()))
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         is_half_range = data.min() >= 0
         if (
-            self.qdesc.scheme in [torch.per_channel_affine, torch.per_tensor_affine]
+            self.qdesc.scheme
+            in [torch.per_channel_affine, torch.per_tensor_affine, "per-group-affine"]
             and is_half_range
         ):
             max_val = self.alpha_laplace_positive[self.qdesc.bit] * b
@@ -85,25 +87,26 @@ def calc_gaus_minmax(self):
     def calc_gaus_minmax(self):
         if self.qdesc.target == QuantTarget.FEATURE:
             batch_size = self.data_cache.get_batch_size()
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data.max(axis=1).values
             min_val = data.min(axis=1).values
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        elif self.granularity == Granularity.LAYERWISE:
             max_val = data.max()
             min_val = data.min()
-            self.data_cache.get_batch_size
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         is_half_range = data.min() >= 0
-        num_elements = data.numel()
+        num_elements = data[0].numel()
         if self.qdesc.target == QuantTarget.FEATURE:
             num_elements /= batch_size
         std = ((max_val - min_val) * self.gaus_const) / (
             (2 * math.log(num_elements)) ** 0.5
         )
         if (
-            self.qdesc.scheme in [torch.per_channel_affine, torch.per_tensor_affine]
+            self.qdesc.scheme
+            in [torch.per_channel_affine, torch.per_tensor_affine, "per-group-affine"]
             and is_half_range
         ):
             max_val = self.alpha_gaus_positive[self.qdesc.bit] * std
```
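The switch to `data[0].numel()` matters because calibration data now always arrives as a `[rows, elements]` matrix, so the Gaussian std estimate must use the per-row element count. A toy check of that estimator under standard-normal data (shapes made up; not the library's code path):

```python
# Toy check of the ACIQ Gaussian std estimate: for N(0, 1) inputs the
# range-based estimate should land roughly at 1.0 per row.
import math

import torch

gaus_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5)

data = torch.randn(4, 10_000)      # 4 channels/groups x 10k samples each
max_val = data.max(dim=1).values
min_val = data.min(dim=1).values
num_elements = data[0].numel()     # per-row count, as in the new code
std = ((max_val - min_val) * gaus_const) / ((2 * math.log(num_elements)) ** 0.5)
print(std)                         # roughly a vector of ones
```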

sparsebit/quantization/observers/base.py

Lines changed: 62 additions & 13 deletions
```diff
@@ -10,7 +10,10 @@ def __init__(self, qdesc):
         self._data_cache = []

     def update(self, data):
-        self._data_cache.append(data)
+        if self.ch_axis != 0:
+            self._data_cache.append(data.transpose(self.ch_axis, 0))
+        else:
+            self._data_cache.append(data)

     def reset(self):
         self._data_cache = []
@@ -23,16 +26,42 @@ def get_data_for_calibration(self, granularity: Granularity):
         assert granularity in [
             Granularity.LAYERWISE,
             Granularity.CHANNELWISE,
-        ], "only layerwise or channelwise quantization are supported now!"
-        if granularity == Granularity.CHANNELWISE:
-            data = torch.cat(self._data_cache, dim=self.qdesc.ch_axis)
-            if self.qdesc.ch_axis != 0:
-                data = data.transpose(0, self.qdesc.ch_axis)
-            data = data.flatten(1)
-        elif granularity == Granularity.LAYERWISE:
-            data = torch.cat([d.reshape(-1) for d in self._data_cache], axis=0)
-        else:
-            raise NotImplementedError
+            Granularity.GROUPWISE,
+        ], "only layerwise, channelwise and groupwise quantization are supported now!"
+        if granularity == Granularity.LAYERWISE:
+            data = torch.cat([d.reshape(1, -1) for d in self._data_cache], axis=1)
+        elif granularity == Granularity.CHANNELWISE:
+            data = torch.cat(
+                [d.reshape(d.shape[0], -1) for d in self._data_cache], axis=1
+            )
+        elif granularity == Granularity.GROUPWISE:
+            if self.target == QuantTarget.FEATURE:  # feature group on channel dim
+                assert (
+                    self._data_cache[0].shape[0] <= self.group_size
+                    or self._data_cache[0].shape[0] % self.group_size == 0
+                ), "channel num must be divisible by group size! got group size {} and channel num {}".format(
+                    self.group_size, self._data_cache[0].shape[0]
+                )
+                group_num = max(self._data_cache[0].shape[0] // self.group_size, 1)
+                if group_num == 1:
+                    self.qdesc.set_group_size = self._data_cache[0].shape[0]
+                data = torch.cat(
+                    [d.reshape(group_num, -1) for d in self._data_cache], axis=1
+                )
+            else:  # weight group on ic dim
+                assert (
+                    self._data_cache[0].shape[1] <= self.group_size
+                    or self._data_cache[0].shape[1] % self.group_size == 0
+                ), "ic num must be divisible by group size! got group size {} and ic num {}".format(
+                    self.group_size, self._data_cache[0].shape[1]
+                )
+                group_num = max(self._data_cache[0].shape[1] // self.group_size, 1)
+                if group_num == 1:
+                    self.qdesc.set_group_size = self._data_cache[0].shape[1]
+                data = torch.cat(
+                    [d.reshape(d.shape[0] * group_num, -1) for d in self._data_cache],
+                    axis=1,
+                )
         return data

     def get_batch_size(self):
@@ -44,6 +73,18 @@ def get_data_cache(self):
         assert len(self._data_cache), "No data cached!"
         return self._data_cache

+    @property
+    def target(self):
+        return self.qdesc.target
+
+    @property
+    def group_size(self):
+        return self.qdesc.group_size
+
+    @property
+    def ch_axis(self):
+        return self.qdesc.ch_axis
+

 class Observer(nn.Module):
     def __init__(self, config, qdesc):
@@ -79,9 +120,17 @@ def calc_qparams_with_minmax(self, min_val, max_val):
         return scale, zero_point

     @property
-    def is_perchannel(self):
-        return self.qdesc.is_perchannel
+    def granularity(self):
+        return self.qdesc.granularity

     @property
     def is_symmetric(self):
         return self.qdesc.is_symmetric
+
+    @property
+    def target(self):
+        return self.qdesc.target
+
+    @property
+    def group_size(self):
+        return self.qdesc.group_size
```
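The GROUPWISE branch is the heart of the commit: a weight of shape `[oc, ic]` with group size `g` is flattened to `[oc * ic // g, g]`, so every row becomes one quantization group and all downstream per-row reductions (min/max, MSE) apply unchanged. A standalone sketch with made-up shapes:

```python
# Standalone illustration of the weight-side GROUPWISE reshape above
# (shapes are invented; sparsebit drives this through the data cache).
import torch

oc, ic, group_size = 8, 64, 16
w = torch.randn(oc, ic)

group_num = max(ic // group_size, 1)      # 4 groups per output channel
rows = w.reshape(oc * group_num, -1)      # [32, 16]: one row per group

min_val = rows.min(dim=1).values          # per-group minima, shape [32]
max_val = rows.max(dim=1).values          # per-group maxima, shape [32]
print(rows.shape, min_val.shape)          # torch.Size([32, 16]) torch.Size([32])
```

The feature path is the same idea on the channel dim (`shape[0]`), which the new `update()` guarantees is dim 0 by transposing `ch_axis` to the front before caching.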

sparsebit/quantization/observers/kl_histogram.py

Lines changed: 5 additions & 6 deletions
```diff
@@ -102,10 +102,8 @@ def __init__(self, config, qdesc):
         self.bins = 2048

     def calc_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(
-                Granularity.CHANNELWISE
-            ).cpu()
+        data = self.data_cache.get_data_for_calibration(self.granularity).cpu()
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             channel = data.shape[0]
             abs_max = data.abs().max(axis=1).values
             _min = torch.empty(channel)
@@ -131,8 +129,7 @@
                     _max[c] = th[c]
             self.max_val = _max.to(self.device)
             self.min_val = _min.to(self.device)
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE).cpu()
+        elif self.granularity == Granularity.LAYERWISE:
             abs_max = data.abs().max()
             th = get_best_threshold(
                 data=data,
@@ -147,5 +144,7 @@
                 if data.min() < 0
                 else torch.zeros(1).to(self.device)
             )
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         return self.min_val, self.max_val
```

sparsebit/quantization/observers/minmax.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -1,7 +1,7 @@
 import torch
 from sparsebit.quantization.observers import Observer as BaseObserver
 from sparsebit.quantization.observers import register_observer
-from sparsebit.quantization.common import Granularity
+from sparsebit.quantization.common import Granularity, QuantTarget


 @register_observer
@@ -12,13 +12,13 @@ def __init__(self, config, qdesc):
         super(Observer, self).__init__(config, qdesc)

     def calc_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data.max(axis=1).values
             min_val = data.min(axis=1).values
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        else:  # Granularity.LAYERWISE
             min_val, max_val = data.min(), data.max()
+
         self.data_cache.reset()
         self.min_val = min_val.to(self.device)
         self.max_val = max_val.to(self.device)
```

sparsebit/quantization/observers/moving_average.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -1,7 +1,7 @@
 import torch
 from sparsebit.quantization.observers import Observer as BaseObserver
 from sparsebit.quantization.observers import register_observer
-from sparsebit.quantization.common import QuantTarget
+from sparsebit.quantization.common import Granularity, QuantTarget


 @register_observer
@@ -14,6 +14,9 @@ def __init__(self, config, qdesc):
             hasattr(config.OBSERVER, "MOVING_AVERAGE")
             and self.qdesc.target == QuantTarget.FEATURE
         ), "Moving_average observer only support feature observing!"
+        assert (
+            self.granularity == Granularity.LAYERWISE
+        ), "Moving_average observer only support layerwise quantization!"
         self.ema_ratio = config.OBSERVER.MOVING_AVERAGE.EMA_RATIO

     def calc_minmax(self):
```

sparsebit/quantization/observers/mse.py

Lines changed: 14 additions & 9 deletions
```diff
@@ -16,21 +16,23 @@ def __init__(self, config, qdesc):
         self.alpha = config.OBSERVER.PERCENTILE.ALPHA

     def calc_minmax(self, data_c_first):
-        if self.is_perchannel:
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data_c_first.max(axis=1).values
             min_val = data_c_first.min(axis=1).values
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             min_val, max_val = data_c_first.min(), data_c_first.max()
+        else:
+            raise NotImplementedError
         self.min_val = min_val.to(self.device)
         self.max_val = max_val.to(self.device)
         return self.min_val, self.max_val

     def calc_qparams(self):
-        data_c_first = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data_c_first = self.data_cache.get_data_for_calibration(self.granularity)
         self.data_cache.reset()
         min_val, max_val = self.calc_minmax(data_c_first)
         x_f = data_c_first.to(self.device)
-        if self.is_perchannel:
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             best_scale = torch.tensor(
                 [1.0 for _ in range(data_c_first.shape[0])], device=self.device
             )
@@ -40,24 +42,27 @@ def calc_qparams(self):
             loss_min = torch.tensor(
                 [1e10 for _ in range(data_c_first.shape[0])], device=self.device
             )
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             best_scale, best_zero_point = None, None
             loss_min = 1e10
+        else:
+            raise NotImplementedError
         for i in range(80):
             cur_min_val = min_val * (1.0 - (i * 0.01))
             cur_max_val = max_val * (1.0 - (i * 0.01))
             scale, zero_point = self.calc_qparams_with_minmax(cur_min_val, cur_max_val)
             x_dq = STE.apply(x_f, scale, zero_point, self.qdesc, Backend.VIRTUAL)
-            if self.is_perchannel:
-                loss = mse_loss(x_f, x_dq, is_perchannel=True)
+            loss = mse_loss(x_f, x_dq, self.granularity)
+            if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
                 best_scale[loss < loss_min] = scale[loss < loss_min]
                 best_zero_point[loss < loss_min] = zero_point[loss < loss_min]
                 loss_min[loss < loss_min] = loss[loss < loss_min]
-            else:
-                loss = mse_loss(x_f, x_dq, is_perchannel=False)
+            elif self.granularity == Granularity.LAYERWISE:
                 if loss < loss_min:
                     loss_min = loss
                     best_scale = scale
                     best_zero_point = zero_point
+            else:
+                raise NotImplementedError
         assert len(self.data_cache) == 0, "free data cache after calc_qparams"
         return best_scale, best_zero_point
```
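The search itself is unchanged by the commit: 80 candidates, each shrinking min/max by a further 1%, keeping whichever clip range minimizes quantization MSE. A scalar sketch of that loop with a hand-rolled 8-bit affine quantizer standing in for sparsebit's STE (an assumption for illustration, not the library's API):

```python
# Scalar sketch of the MSE range search above; the quantize/dequantize
# round-trip here is a plain 8-bit affine stand-in, not sparsebit's STE.
import torch

x = torch.randn(10_000)
lo, hi = x.min(), x.max()
best_loss, best_range = float("inf"), (lo.item(), hi.item())
for i in range(80):
    cur_lo, cur_hi = lo * (1.0 - i * 0.01), hi * (1.0 - i * 0.01)
    scale = (cur_hi - cur_lo) / 255
    zp = torch.round(-cur_lo / scale)
    x_dq = (torch.clamp(torch.round(x / scale) + zp, 0, 255) - zp) * scale
    loss = ((x - x_dq) ** 2).mean()
    if loss < best_loss:
        best_loss, best_range = loss.item(), (cur_lo.item(), cur_hi.item())
print(best_range)  # usually tighter than the raw min/max
```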

sparsebit/quantization/observers/percentile.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -14,13 +14,7 @@ def __init__(self, config, qdesc):
         self.alpha = config.OBSERVER.PERCENTILE.ALPHA

     def calc_minmax(self):
-
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
-        else:
-            data = self.data_cache.get_data_for_calibration(
-                Granularity.LAYERWISE
-            ).reshape(1, -1)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
         self.data_cache.reset()
         channel = data.shape[0]

```
(file path not captured in this page — the mse_loss helper)

Lines changed: 8 additions & 3 deletions
```diff
@@ -1,5 +1,10 @@
-def mse_loss(pred, tgt, is_perchannel=False):
-    if is_perchannel:
+from sparsebit.quantization.common import Granularity
+
+
+def mse_loss(pred, tgt, granularity: Granularity):
+    if granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
         return ((pred - tgt) ** 2).mean(-1)
-    else:
+    elif granularity == Granularity.LAYERWISE:
         return ((pred - tgt) ** 2).mean()
+    else:
+        raise NotImplementedError
```
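The new signature makes the reduction axis follow granularity: per-row losses for channel/group-wise (which is what lets `calc_qparams` mask-update `best_scale` row by row), a scalar for layerwise. A quick behavior check with a local stand-in enum so the snippet runs without sparsebit on the path:

```python
# Behavior check for the reworked mse_loss (local stand-in Granularity).
from enum import Enum

import torch


class Granularity(Enum):
    LAYERWISE = 0
    CHANNELWISE = 1
    GROUPWISE = 2


def mse_loss(pred, tgt, granularity):
    if granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
        return ((pred - tgt) ** 2).mean(-1)   # one loss per channel/group row
    elif granularity == Granularity.LAYERWISE:
        return ((pred - tgt) ** 2).mean()     # single scalar loss
    raise NotImplementedError


x = torch.randn(4, 100)
y = x + 0.01 * torch.randn_like(x)
print(mse_loss(x, y, Granularity.GROUPWISE).shape)   # torch.Size([4])
print(mse_loss(x, y, Granularity.LAYERWISE).shape)   # torch.Size([])
```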

sparsebit/quantization/quant_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 _C.W.QUANTIZER.TYPE = "uniform"
 _C.W.QUANTIZER.DISABLE = False
 _C.W.QUANTIZER.BIT = -1
+_C.W.QUANTIZER.GROUP_SIZE = -1
 _C.W.OBSERVER = CN()
 _C.W.OBSERVER.TYPE = "MINMAX"  # "MINMAX"/"MSE"/"PERCENTILE"/"KL_HISTOGRAM"
 _C.W.OBSERVER.PERCENTILE = CN()
@@ -32,6 +33,7 @@
 _C.A.QUANTIZER.TYPE = "uniform"
 _C.A.QUANTIZER.DISABLE = False
 _C.A.QUANTIZER.BIT = -1
+_C.A.QUANTIZER.GROUP_SIZE = -1
 _C.A.QUANTIZER.PACT = CN()
 _C.A.QUANTIZER.PACT.ALPHA_VALUE = 10
 _C.A.OBSERVER = CN()
```
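With these defaults in place, groupwise quantization is switched on by overriding `GROUP_SIZE` with a positive value (-1 leaves it off). A minimal yacs sketch mirroring the keys above; the 4-bit / group-of-128 values are just an example configuration, not a recommendation from the repo:

```python
# Minimal yacs sketch of setting the new knob via an override list.
from yacs.config import CfgNode as CN

_C = CN()
_C.W = CN()
_C.W.QUANTIZER = CN()
_C.W.QUANTIZER.TYPE = "uniform"
_C.W.QUANTIZER.BIT = -1
_C.W.QUANTIZER.GROUP_SIZE = -1

cfg = _C.clone()
cfg.merge_from_list(["W.QUANTIZER.BIT", 4, "W.QUANTIZER.GROUP_SIZE", 128])
print(cfg.W.QUANTIZER.GROUP_SIZE)  # 128
```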
