Skip to content

Commit 42cefff

Browse files
committed
black
1 parent 966a12b commit 42cefff

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

sparsebit/quantization/quant_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ def prepare_calibration(self):
185185
from sparsebit.quantization.tools.calibration import CalibrationRunner
186186

187187
self.eval()
188-
self.calibration_runner = CalibrationRunner(self.model, self.cfg.SCHEDULE.BIAS_CORRECTION)
188+
self.calibration_runner = CalibrationRunner(
189+
self.model, self.cfg.SCHEDULE.BIAS_CORRECTION
190+
)
189191
self.calibration_runner.prepare_calibration()
190192

191193
def calc_qparams(self):

sparsebit/quantization/tools/calibration.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,27 +106,54 @@ def feature_layerwise_calibration(self, device):
106106
self.builder.storage.set_output(node.target, outputs)
107107

108108
if self.bias_correction:
109-
if isinstance(module, QuantOpr) and getattr(module, "weight_quantizer", None):
109+
if isinstance(module, QuantOpr) and getattr(
110+
module, "weight_quantizer", None
111+
):
110112
for inp_node in node.all_input_nodes:
111113
inp_tensors = self.builder.storage.get_output(inp_node.target)
112114
float_outputs = torch.Tensor([])
113115
quant_outputs = torch.Tensor([])
114-
float_outputs_cached = self.builder.storage.get_output(node.target)
116+
float_outputs_cached = self.builder.storage.get_output(
117+
node.target
118+
)
115119
for idx in range(batch_num):
116120
inp_tensor = inp_tensors[idx].cuda()
117121
with torch.no_grad():
118122
float_output = float_outputs_cached[idx]
119123
module.set_quant(True, False)
120124
quant_output = module(inp_tensor).cpu()
121125
module.set_quant(False, False)
122-
float_outputs = torch.cat((float_outputs, float_output.detach()), 0)
123-
quant_outputs = torch.cat((quant_outputs, quant_output.detach()), 0)
124-
float_output_mean = float_outputs.transpose(module.input_quantizer.qdesc._ch_axis,0).flatten(1).mean(-1)
125-
quant_output_mean = quant_outputs.transpose(module.input_quantizer.qdesc._ch_axis,0).flatten(1).mean(-1)
126+
float_outputs = torch.cat(
127+
(float_outputs, float_output.detach()), 0
128+
)
129+
quant_outputs = torch.cat(
130+
(quant_outputs, quant_output.detach()), 0
131+
)
132+
float_output_mean = (
133+
float_outputs.transpose(
134+
module.input_quantizer.qdesc._ch_axis, 0
135+
)
136+
.flatten(1)
137+
.mean(-1)
138+
)
139+
quant_output_mean = (
140+
quant_outputs.transpose(
141+
module.input_quantizer.qdesc._ch_axis, 0
142+
)
143+
.flatten(1)
144+
.mean(-1)
145+
)
126146
bias = quant_output_mean - float_output_mean
127147
if module.bias is None:
128-
module.bias = nn.Parameter(data=torch.zeros(module.weight.size(0), dtype=torch.float32, device=device), requires_grad=False)
129-
module.bias.data = module.bias.data-bias.cuda()
148+
module.bias = nn.Parameter(
149+
data=torch.zeros(
150+
module.weight.size(0),
151+
dtype=torch.float32,
152+
device=device,
153+
),
154+
requires_grad=False,
155+
)
156+
module.bias.data = module.bias.data - bias.cuda()
130157

131158
self.builder.storage.finish_node(node.target)
132159

0 commit comments

Comments
 (0)