@@ -10,7 +10,10 @@ def __init__(self, qdesc):
         self._data_cache = []

     def update(self, data):
-        self._data_cache.append(data)
+        if self.ch_axis != 0:
+            self._data_cache.append(data.transpose(self.ch_axis, 0))
+        else:
+            self._data_cache.append(data)

     def reset(self):
         self._data_cache = []
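# A minimal sketch of what update() now caches, assuming a feature observer
# over NCHW activations with ch_axis=1 (shapes here are assumptions; only the
# transpose itself comes from the diff):
import torch

x = torch.randn(8, 16, 4, 4)    # (batch, channels, h, w), ch_axis = 1
cached = x.transpose(1, 0)      # what update() appends: (16, 8, 4, 4)
assert cached.shape[0] == 16    # channel axis sits at dim 0 for every entry
# Moving the channel axis up front lets get_data_for_calibration() below treat
# shape[0] as the channel count for every granularity.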
@@ -23,16 +26,42 @@ def get_data_for_calibration(self, granularity: Granularity):
         assert granularity in [
             Granularity.LAYERWISE,
             Granularity.CHANNELWISE,
-        ], "only layerwise or channelwise quantization are supported now!"
-        if granularity == Granularity.CHANNELWISE:
-            data = torch.cat(self._data_cache, dim=self.qdesc.ch_axis)
-            if self.qdesc.ch_axis != 0:
-                data = data.transpose(0, self.qdesc.ch_axis)
-            data = data.flatten(1)
-        elif granularity == Granularity.LAYERWISE:
-            data = torch.cat([d.reshape(-1) for d in self._data_cache], axis=0)
-        else:
-            raise NotImplementedError
+            Granularity.GROUPWISE,
+        ], "only layerwise, channelwise and groupwise quantization are supported now!"
+        if granularity == Granularity.LAYERWISE:
+            data = torch.cat([d.reshape(1, -1) for d in self._data_cache], axis=1)
+        elif granularity == Granularity.CHANNELWISE:
+            data = torch.cat(
+                [d.reshape(d.shape[0], -1) for d in self._data_cache], axis=1
+            )
+        elif granularity == Granularity.GROUPWISE:
+            if self.target == QuantTarget.FEATURE:  # feature group on channel dim
+                assert (
+                    self._data_cache[0].shape[0] <= self.group_size
+                    or self._data_cache[0].shape[0] % self.group_size == 0
+                ), "channel num must be divisible by group size! got group size {} and channel num {}".format(
+                    self.group_size, self._data_cache[0].shape[0]
+                )
+                group_num = max(self._data_cache[0].shape[0] // self.group_size, 1)
+                if group_num == 1:
+                    self.qdesc.set_group_size = self._data_cache[0].shape[0]
+                data = torch.cat(
+                    [d.reshape(group_num, -1) for d in self._data_cache], axis=1
+                )
+            else:  # weight group on ic dim
+                assert (
+                    self._data_cache[0].shape[1] <= self.group_size
+                    or self._data_cache[0].shape[1] % self.group_size == 0
+                ), "ic num must be divisible by group size! got group size {} and ic num {}".format(
+                    self.group_size, self._data_cache[0].shape[1]
+                )
+                group_num = max(self._data_cache[0].shape[1] // self.group_size, 1)
+                if group_num == 1:
+                    self.qdesc.set_group_size = self._data_cache[0].shape[1]
+                data = torch.cat(
+                    [d.reshape(d.shape[0] * group_num, -1) for d in self._data_cache],
+                    axis=1,
+                )
         return data

     def get_batch_size(self):
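# A minimal sketch of the GROUPWISE weight path above, assuming a (64, 32)
# linear weight and group_size=8 (shapes and names are assumptions):
import torch

oc, ic, group_size = 64, 32, 8
w = torch.randn(oc, ic)
group_num = max(ic // group_size, 1)     # 4 groups along the ic dim
data = w.reshape(oc * group_num, -1)     # (256, 8): one row per (oc, group)
assert data.shape == (256, 8)
# Each row is one contiguous ic-group of one output channel, so row-wise
# statistics become per-group quantization parameters. When group_num == 1
# the code above writes the actual width back via qdesc.set_group_size,
# which presumably resolves to a property setter on the qdesc.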
@@ -44,6 +73,18 @@ def get_data_cache(self):
         assert len(self._data_cache), "No data cached!"
         return self._data_cache

+    @property
+    def target(self):
+        return self.qdesc.target
+
+    @property
+    def group_size(self):
+        return self.qdesc.group_size
+
+    @property
+    def ch_axis(self):
+        return self.qdesc.ch_axis
+

 class Observer(nn.Module):
     def __init__(self, config, qdesc):
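# The pass-through properties above assume qdesc exposes these fields; a
# hypothetical minimal stand-in for illustration (the real qdesc class in
# this repo may differ):
from dataclasses import dataclass

@dataclass
class QDescSketch:
    target: object     # e.g. QuantTarget.FEATURE or QuantTarget.WEIGHT
    group_size: int    # group width used by GROUPWISE granularity
    ch_axis: int       # channel axis of tensors passed to update()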
@@ -79,9 +120,17 @@ def calc_qparams_with_minmax(self, min_val, max_val):
         return scale, zero_point

     @property
-    def is_perchannel(self):
-        return self.qdesc.is_perchannel
+    def granularity(self):
+        return self.qdesc.granularity

     @property
     def is_symmetric(self):
         return self.qdesc.is_symmetric
+
+    @property
+    def target(self):
+        return self.qdesc.target
+
+    @property
+    def group_size(self):
+        return self.qdesc.group_size
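# A minimal sketch of how a min/max observer could consume the calibration
# matrix under the new properties (tensor shapes are assumptions):
# get_data_for_calibration() returns one row per quantization unit (layer,
# channel, or group), so per-unit ranges reduce along dim=1 at every
# granularity.
import torch

data = torch.randn(16, 1024)         # e.g. 16 channels/groups, 1024 samples
min_val = data.min(dim=1).values     # per-unit minimum, shape (16,)
max_val = data.max(dim=1).values     # per-unit maximum, shape (16,)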