@@ -106,27 +106,54 @@ def feature_layerwise_calibration(self, device):
106106 self .builder .storage .set_output (node .target , outputs )
107107
108108 if self .bias_correction :
109- if isinstance (module , QuantOpr ) and getattr (module , "weight_quantizer" , None ):
109+ if isinstance (module , QuantOpr ) and getattr (
110+ module , "weight_quantizer" , None
111+ ):
110112 for inp_node in node .all_input_nodes :
111113 inp_tensors = self .builder .storage .get_output (inp_node .target )
112114 float_outputs = torch .Tensor ([])
113115 quant_outputs = torch .Tensor ([])
114- float_outputs_cached = self .builder .storage .get_output (node .target )
116+ float_outputs_cached = self .builder .storage .get_output (
117+ node .target
118+ )
115119 for idx in range (batch_num ):
116120 inp_tensor = inp_tensors [idx ].cuda ()
117121 with torch .no_grad ():
118122 float_output = float_outputs_cached [idx ]
119123 module .set_quant (True , False )
120124 quant_output = module (inp_tensor ).cpu ()
121125 module .set_quant (False , False )
122- float_outputs = torch .cat ((float_outputs , float_output .detach ()), 0 )
123- quant_outputs = torch .cat ((quant_outputs , quant_output .detach ()), 0 )
124- float_output_mean = float_outputs .transpose (module .input_quantizer .qdesc ._ch_axis ,0 ).flatten (1 ).mean (- 1 )
125- quant_output_mean = quant_outputs .transpose (module .input_quantizer .qdesc ._ch_axis ,0 ).flatten (1 ).mean (- 1 )
126+ float_outputs = torch .cat (
127+ (float_outputs , float_output .detach ()), 0
128+ )
129+ quant_outputs = torch .cat (
130+ (quant_outputs , quant_output .detach ()), 0
131+ )
132+ float_output_mean = (
133+ float_outputs .transpose (
134+ module .input_quantizer .qdesc ._ch_axis , 0
135+ )
136+ .flatten (1 )
137+ .mean (- 1 )
138+ )
139+ quant_output_mean = (
140+ quant_outputs .transpose (
141+ module .input_quantizer .qdesc ._ch_axis , 0
142+ )
143+ .flatten (1 )
144+ .mean (- 1 )
145+ )
126146 bias = quant_output_mean - float_output_mean
127147 if module .bias is None :
128- module .bias = nn .Parameter (data = torch .zeros (module .weight .size (0 ), dtype = torch .float32 , device = device ), requires_grad = False )
129- module .bias .data = module .bias .data - bias .cuda ()
148+ module .bias = nn .Parameter (
149+ data = torch .zeros (
150+ module .weight .size (0 ),
151+ dtype = torch .float32 ,
152+ device = device ,
153+ ),
154+ requires_grad = False ,
155+ )
156+ module .bias .data = module .bias .data - bias .cuda ()
130157
131158 self .builder .storage .finish_node (node .target )
132159
0 commit comments