|
|
|
@@ -138,6 +138,7 @@ class QuantizationAwareTraining(Quantizer): |
|
|
|
The first element represents weights and the second element represents data flow. Default: (False, False) |
|
|
|
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only |
|
|
|
supports QAT. Default: OptimizeOption.QAT |
|
|
|
one_conv_fold (bool): Flag to use one conv bn fold op for simulation inference operation. Default: True. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> class LeNet5(nn.Cell): |
|
|
|
@@ -182,7 +183,8 @@ class QuantizationAwareTraining(Quantizer): |
|
|
|
per_channel=(False, False), |
|
|
|
symmetric=(False, False), |
|
|
|
narrow_range=(False, False), |
|
|
|
optimize_option=OptimizeOption.QAT): |
|
|
|
optimize_option=OptimizeOption.QAT, |
|
|
|
one_conv_fold=True): |
|
|
|
"""Init for QuantizationAwareTraining quantizer""" |
|
|
|
super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option) |
|
|
|
def convert2list(name, value): |
|
|
|
@@ -210,6 +212,7 @@ class QuantizationAwareTraining(Quantizer): |
|
|
|
self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric") |
|
|
|
self.weight_range = Validator.check_bool(narrow_range[0], "narrow range") |
|
|
|
self.act_range = Validator.check_bool(narrow_range[-1], "narrow range") |
|
|
|
self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold") |
|
|
|
self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv, |
|
|
|
nn.DenseBnAct: self._convert_dense} |
|
|
|
self.quant_config = create_quant_config(quant_delay=quant_delay, |
|
|
|
@@ -300,22 +303,39 @@ class QuantizationAwareTraining(Quantizer): |
|
|
|
if subcell.has_bn: |
|
|
|
if self.bn_fold: |
|
|
|
bn_inner = subcell.batchnorm |
|
|
|
conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels, |
|
|
|
conv_inner.out_channels, |
|
|
|
kernel_size=conv_inner.kernel_size, |
|
|
|
stride=conv_inner.stride, |
|
|
|
pad_mode=conv_inner.pad_mode, |
|
|
|
padding=conv_inner.padding, |
|
|
|
dilation=conv_inner.dilation, |
|
|
|
group=conv_inner.group, |
|
|
|
eps=bn_inner.eps, |
|
|
|
momentum=bn_inner.momentum, |
|
|
|
has_bias=conv_inner.has_bias, |
|
|
|
bias_init=conv_inner.bias_init, |
|
|
|
freeze_bn=self.freeze_bn, |
|
|
|
quant_config=self.quant_config, |
|
|
|
quant_dtype=self.weight_dtype, |
|
|
|
fake=True) |
|
|
|
if self.one_conv_fold: |
|
|
|
conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels, |
|
|
|
conv_inner.out_channels, |
|
|
|
kernel_size=conv_inner.kernel_size, |
|
|
|
stride=conv_inner.stride, |
|
|
|
pad_mode=conv_inner.pad_mode, |
|
|
|
padding=conv_inner.padding, |
|
|
|
dilation=conv_inner.dilation, |
|
|
|
group=conv_inner.group, |
|
|
|
eps=bn_inner.eps, |
|
|
|
momentum=bn_inner.momentum, |
|
|
|
has_bias=conv_inner.has_bias, |
|
|
|
bias_init=conv_inner.bias_init, |
|
|
|
quant_config=self.quant_config, |
|
|
|
quant_dtype=self.weight_dtype, |
|
|
|
fake=True) |
|
|
|
else: |
|
|
|
conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels, |
|
|
|
conv_inner.out_channels, |
|
|
|
kernel_size=conv_inner.kernel_size, |
|
|
|
stride=conv_inner.stride, |
|
|
|
pad_mode=conv_inner.pad_mode, |
|
|
|
padding=conv_inner.padding, |
|
|
|
dilation=conv_inner.dilation, |
|
|
|
group=conv_inner.group, |
|
|
|
eps=bn_inner.eps, |
|
|
|
momentum=bn_inner.momentum, |
|
|
|
has_bias=conv_inner.has_bias, |
|
|
|
bias_init=conv_inner.bias_init, |
|
|
|
freeze_bn=self.freeze_bn, |
|
|
|
quant_config=self.quant_config, |
|
|
|
quant_dtype=self.weight_dtype, |
|
|
|
fake=True) |
|
|
|
# copy the original network's BatchNorm OP parameters into the quant network |
|
|
|
conv_inner.gamma = subcell.batchnorm.gamma |
|
|
|
conv_inner.beta = subcell.batchnorm.beta |
|
|
|
|