
Commit 966a12b

add bias correction

add bias correction; rectify in progress; resnet18 ok; deit ok

1 parent a665d03 · commit 966a12b

File tree

4 files changed: +35 −4 lines

examples/post_training_quantization/imagenet1k/deit/qconfig.yaml

Lines changed: 5 additions & 1 deletion

```diff
@@ -15,4 +15,8 @@ A:
   BIT: 8
   OBSERVER:
     TYPE: MINMAX
-    LAYOUT: NCHW
+    LAYOUT: NLC
+  SPECIFIC: [{
+    "patch_embed_proj": ["OBSERVER.LAYOUT", "NCHW"],
+    "head": ["OBSERVER.LAYOUT", "NCHW"],
+  }]
```
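The DeiT activations now default to the NLC layout used by transformer blocks, while `patch_embed_proj` and `head` keep NCHW via per-module SPECIFIC overrides. A hypothetical sketch of how such an override could be resolved into a per-module config; `apply_specific_overrides` is illustrative, not sparsebit's actual API (it assumes the yacs `CfgNode` that quant_config.py builds on):

```python
from yacs.config import CfgNode as CN

def apply_specific_overrides(base_cfg, module_name):
    """Clone the global tensor config, then overwrite any (key, value)
    pairs listed for `module_name` under SPECIFIC (illustrative helper)."""
    cfg = base_cfg.clone()
    for spec in cfg.get("SPECIFIC", []):
        for name, (key, value) in spec.items():
            if name == module_name:
                cfg.merge_from_list([key, value])  # e.g. "OBSERVER.LAYOUT", "NCHW"
    return cfg

base = CN({"OBSERVER": {"TYPE": "MINMAX", "LAYOUT": "NLC"},
           "SPECIFIC": [{"patch_embed_proj": ["OBSERVER.LAYOUT", "NCHW"]}]})
print(apply_specific_overrides(base, "patch_embed_proj").OBSERVER.LAYOUT)  # NCHW
print(apply_specific_overrides(base, "some_other_module").OBSERVER.LAYOUT)  # NLC
```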

sparsebit/quantization/quant_config.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -9,6 +9,7 @@

 _C.SCHEDULE = CN()
 _C.SCHEDULE.FUSE_BN = False  # use ``with torch.no_grad()`` if it's enabled
+_C.SCHEDULE.BIAS_CORRECTION = False
 _C.SCHEDULE.BN_TUNING = False
 _C.SCHEDULE.DISABLE_UNNECESSARY_QUANT = True
```
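The new switch defaults to `False`, so existing configs are unaffected. A minimal sketch of enabling it programmatically (assuming `parse_qconfig` is exported as in the sparsebit examples, and standard yacs defrost/freeze semantics):

```python
from sparsebit.quantization import parse_qconfig

qconfig = parse_qconfig("qconfig.yaml")
qconfig.defrost()
qconfig.SCHEDULE.BIAS_CORRECTION = True  # new flag added by this commit
qconfig.freeze()
```

Equivalently, a `SCHEDULE: {BIAS_CORRECTION: True}` entry in the yaml should merge over the default.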

sparsebit/quantization/quant_model.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -185,13 +185,13 @@ def prepare_calibration(self):
         from sparsebit.quantization.tools.calibration import CalibrationRunner

         self.eval()
-        self.calibration_runner = CalibrationRunner(self.model)
+        self.calibration_runner = CalibrationRunner(self.model, self.cfg.SCHEDULE.BIAS_CORRECTION)
         self.calibration_runner.prepare_calibration()

     def calc_qparams(self):
         assert hasattr(self, "calibration_runner"), "run self.prepare_calibration first"
-        self.calibration_runner.feature_layerwise_calibration(self.device)
         self.calibration_runner.weight_calibration()
+        self.calibration_runner.feature_layerwise_calibration(self.device)
         del self.calibration_runner

     def init_QAT(self):
```

sparsebit/quantization/tools/calibration.py

Lines changed: 27 additions & 1 deletion

```diff
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 from functools import partial

 from sparsebit.quantization.modules import QuantOpr
@@ -7,8 +8,9 @@


 class CalibrationRunner(object):
-    def __init__(self, model):
+    def __init__(self, model, bias_correction=False):
         self.model = fx_symbolic_trace(model)
+        self.bias_correction = bias_correction

     def prepare_calibration(self):
         input_names_cache = set(
@@ -102,6 +104,30 @@ def feature_layerwise_calibration(self, device):
                 # more time for less cuda memory occupation
                 outputs.append(to_cpu(module(*args, **kwargs)))
             self.builder.storage.set_output(node.target, outputs)
+
+            if self.bias_correction:
+                if isinstance(module, QuantOpr) and getattr(module, "weight_quantizer", None):
+                    for inp_node in node.all_input_nodes:
+                        inp_tensors = self.builder.storage.get_output(inp_node.target)
+                        float_outputs = torch.Tensor([])
+                        quant_outputs = torch.Tensor([])
+                        float_outputs_cached = self.builder.storage.get_output(node.target)
+                        for idx in range(batch_num):
+                            inp_tensor = inp_tensors[idx].cuda()
+                            with torch.no_grad():
+                                float_output = float_outputs_cached[idx]
+                                module.set_quant(True, False)
+                                quant_output = module(inp_tensor).cpu()
+                                module.set_quant(False, False)
+                            float_outputs = torch.cat((float_outputs, float_output.detach()), 0)
+                            quant_outputs = torch.cat((quant_outputs, quant_output.detach()), 0)
+                        float_output_mean = float_outputs.transpose(module.input_quantizer.qdesc._ch_axis, 0).flatten(1).mean(-1)
+                        quant_output_mean = quant_outputs.transpose(module.input_quantizer.qdesc._ch_axis, 0).flatten(1).mean(-1)
+                        bias = quant_output_mean - float_output_mean
+                        if module.bias is None:
+                            module.bias = nn.Parameter(data=torch.zeros(module.weight.size(0), dtype=torch.float32, device=device), requires_grad=False)
+                        module.bias.data = module.bias.data - bias.cuda()
+
             self.builder.storage.finish_node(node.target)

     def weight_calibration(self):
```
