# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Quantization aware training

User can use quantization aware to train a model. MindSpore supports quantization aware training,
which models quantization errors in both the forward and backward passes using fake-quantization
operations. Note that the entire computation is carried out in floating point. At the end of quantization
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
"""

import re

import mindspore.context as context

from ... import nn, ops
from ..._checkparam import Validator, Rel
from ...nn.layer import quant
from ...ops import functional as F
from ..common import QuantDtype
from .quantizer import Quantizer, OptimizeOption


__all__ = ["QuantizationAwareTraining"]


_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
                   nn.ReLU6: quant.ActQuant,
                   nn.Sigmoid: quant.ActQuant,
                   nn.LeakyReLU: quant.LeakyReLUQuant,
                   nn.HSigmoid: quant.HSigmoidQuant,
                   nn.HSwish: quant.HSwishQuant}


def _to_sequence(value):
    """Return `value` unchanged if it is a list or tuple, otherwise wrap it in a one-element tuple."""
    return value if isinstance(value, (list, tuple)) else (value,)


def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
                     quant_delay=(0, 0),
                     quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                     per_channel=(False, False),
                     symmetric=(False, False),
                     narrow_range=(False, False)
                     ):
    r"""
    Configs the observer type of weights and data flow with quant params.

    Every argument may be a single value or a list/tuple. For a list/tuple the first element is used for
    weights and the last element is used for data flow; a single value is used for both.

    Args:
        quant_observer (Observer, list or tuple): The observer type to do quantization. The first element represent
            weights and second element represent data flow.
            Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
        quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
            element represent weights and second element represent data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (bool, list or tuple):  Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represent weights
            and second element represent data flow. Default: (False, False)
        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
            symmetric otherwise base on asymmetric. The first element represent weights and second
            element represent data flow. Default: (False, False)
        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the second element represents data flow. Default: (False, False)

    Returns:
        QuantConfig, Contains the observer type of weight and activation.
    """
    # Accept the scalar forms promised by the docstring: a bare value applies to both weights and data flow.
    quant_observer = _to_sequence(quant_observer)
    quant_delay = _to_sequence(quant_delay)
    quant_dtype = _to_sequence(quant_dtype)
    per_channel = _to_sequence(per_channel)
    symmetric = _to_sequence(symmetric)
    narrow_range = _to_sequence(narrow_range)

    weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
                                                     per_channel=per_channel[0], symmetric=symmetric[0],
                                                     narrow_range=narrow_range[0])
    # The data-flow observer is the LAST element, matching how every other argument is indexed.
    act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
                                                   per_channel=per_channel[-1], symmetric=symmetric[-1],
                                                   narrow_range=narrow_range[-1])
    return quant.QuantConfig(weight=weight_observer, activation=act_observer)


class _AddFakeQuantInput(nn.Cell):
    """
    Add FakeQuant OP at input of the network. Only support one input case.

    Args:
        network (Cell): The network whose single input is passed through the FakeQuant observer first.
        quant_delay (int): Forwarded to the FakeQuant observer as `quant_delay`. Default: 0.
    """

    def __init__(self, network, quant_delay=0):
        super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
        self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
                                                                  quant_delay=quant_delay, ema=True)
        self.fake_quant_input.update_parameters_name('fake_quant_input.')
        self.network = network

    def construct(self, data):
        """Apply the input FakeQuant observer to `data`, then run the wrapped network on the result."""
        data = self.fake_quant_input(data)
        output = self.network(data)
        return output


class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Add FakeQuant OP after of the sub Cell.

    Args:
        subcell: The callable whose output is passed through the FakeQuant observer.
        kwargs: Must contain the keys `quant_dtype`, `quant_delay`, `per_channel`, `symmetric` and
            `narrow_range`; they are forwarded to the FakeQuant observer.
    """

    def __init__(self, subcell, **kwargs):
        super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
        self.subcell = subcell
        self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
                                                                max_init=6,
                                                                ema=True,
                                                                quant_dtype=kwargs["quant_dtype"],
                                                                quant_delay=kwargs["quant_delay"],
                                                                per_channel=kwargs["per_channel"],
                                                                symmetric=kwargs["symmetric"],
                                                                narrow_range=kwargs["narrow_range"])

    def construct(self, *data):
        """Run the wrapped sub cell on `data` and apply the FakeQuant observer to its output."""
        output = self.subcell(*data)
        output = self.fake_quant_act(output)
        return output


class QuantizationAwareTraining(Quantizer):
    r"""
    Quantizer for quantization aware training.

    Args:
        bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7.
        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
        quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
            element represent weights and second element represent data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (bool, list or tuple):  Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represent weights
            and second element represent data flow. Default: (False, False)
        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
            symmetric otherwise base on asymmetric. The first element represent weights and second
            element represent data flow. Default: (False, False)
        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the second element represents data flow. Default: (False, False)
        optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only
            support QAT. Default: OptimizeOption.QAT

    Raises:
        ValueError: If a list/tuple argument is empty or has more than 2 elements.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self, num_class=10, channel=1):
        >>>         super(Net, self).__init__()
        >>>         self.type = "fusion"
        >>>         self.num_class = num_class
        >>>
        >>>         # change `nn.Conv2d` to `nn.Conv2dBnAct`
        >>>         self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        >>>         self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        >>>         # change `nn.Dense` to `nn.DenseBnAct`
        >>>         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        >>>         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        >>>         self.fc3 = nn.DenseBnAct(84, self.num_class)
        >>>
        >>>         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        >>>         self.flatten = nn.Flatten()
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv1(x)
        >>>         x = self.max_pool2d(x)
        >>>         x = self.conv2(x)
        >>>         x = self.max_pool2d(x)
        >>>         x = self.flatten(x)
        >>>         x = self.fc1(x)
        >>>         x = self.fc2(x)
        >>>         x = self.fc3(x)
        >>>         return x
        >>>
        >>> net = Net()
        >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
        >>> net_qat = quantizer.quantize(net)
    """
    # Names of primitive ops that get a FakeQuant observer appended to their output.
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self,
                 bn_fold=True,
                 freeze_bn=10000000,
                 quant_delay=(0, 0),
                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                 per_channel=(False, False),
                 symmetric=(False, False),
                 narrow_range=(False, False),
                 optimize_option=OptimizeOption.QAT):
        """Init for QuantizationAwareTraining quantizer"""
        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)

        def convert2list(name, value):
            """Wrap a scalar in a list; reject sequences that do not hold 1 or 2 elements."""
            if not isinstance(value, (list, tuple)):
                value = [value]
            elif not value or len(value) > 2:
                # An empty sequence would otherwise fail later with an opaque IndexError.
                raise ValueError("input `{}` len should be 1 or 2".format(name))
            return value

        quant_delay = convert2list("quant delay", quant_delay)
        quant_dtype = convert2list("quant dtype", quant_dtype)
        per_channel = convert2list("per channel", per_channel)
        symmetric = convert2list("symmetric", symmetric)
        narrow_range = convert2list("narrow range", narrow_range)

        # Index 0 configures weights, index -1 configures data flow (same element for one-item lists).
        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                    quant.DenseBnAct: self._convert_dense}
        self.quant_config = get_quant_config(quant_delay=quant_delay,
                                             quant_dtype=quant_dtype,
                                             per_channel=per_channel,
                                             symmetric=symmetric,
                                             narrow_range=narrow_range)

    def _convert_op_name(self, name):
        """Convert a CamelCase op name to snake_case, e.g. `TensorAdd` -> `tensor_add`."""
        name_new = re.sub(r'([A-Z]{1})', r'_\1', name).lower()
        # Drop the single underscore produced by a leading capital; `startswith` is safe on an empty name.
        if name_new.startswith('_'):
            name_new = name_new[1:]
        return name_new

    def _add_fake_quant_after(self, op):
        """Wrap `op` so that its output passes through a FakeQuant observer using the activation settings."""
        return _AddFakeQuantAfterSubCell(op,
                                         quant_dtype=self.act_dtype,
                                         quant_delay=self.act_qdelay,
                                         per_channel=self.act_channel,
                                         symmetric=self.act_symmetric,
                                         narrow_range=self.act_range)

    def quantize(self, network):
        """
        Quant API to convert input network to a quantization aware training network

        Args:
            network (Cell): network to be quantized.

        Returns:
            Cell, the network after conversion. It is returned without conversion when
            `OptimizeOption.QAT` is not in the configured optimize option.

        Raises:
            KeyError: If the `device_target` of the context is neither "Ascend" nor "GPU".

        Examples:
            >>> net = Net()
            >>> quantizer = QuantizationAwareTraining()
            >>> net_qat = quantizer.quantize(net)
        """
        support_device = ["Ascend", "GPU"]
        device_target = context.get_context('device_target')
        if device_target not in support_device:
            raise KeyError("Unsupported {} device target.".format(device_target))

        if OptimizeOption.QAT in self.optimize_option:
            network.update_cell_prefix()
            network = self._convert_subcells2quant(network)
            network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell

        Sub cells of other types are processed recursively. Primitive attributes of `network` whose op name is
        listed in `__quant_op_name__` are replaced by a cell that applies a FakeQuant observer to their output.
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                self._convert_subcells2quant(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # add FakeQuant OP after OP in white list
        # Collect first: the loop below mutates `network.__dict__`, which must not happen while iterating it.
        add_list = [(name, attr) for name, attr in network.__dict__.items()
                    if not name.startswith('_')
                    and isinstance(attr, ops.Primitive)
                    and attr.name in self.__quant_op_name__]
        for name, prim_op in add_list:
            add_quant = self._add_fake_quant_after(prim_op)
            prefix = self._convert_op_name(prim_op.name)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, prefix])
            add_quant.update_parameters_name(prefix + '.')
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_conv(self, subcell):
        """
        convert Conv2d cell to quant cell

        The inner convolution of `subcell` is replaced by a quant convolution that reuses the original
        weight (and bias, if any). When `subcell` has a BatchNorm, its parameters are handed over to the new
        cell and the BatchNorm is removed from `subcell`.
        """
        conv_inner = subcell.conv
        if subcell.has_bn:
            bn_inner = subcell.batchnorm
            if self.bn_fold:
                conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
                                                     conv_inner.out_channels,
                                                     kernel_size=conv_inner.kernel_size,
                                                     stride=conv_inner.stride,
                                                     pad_mode=conv_inner.pad_mode,
                                                     padding=conv_inner.padding,
                                                     dilation=conv_inner.dilation,
                                                     group=conv_inner.group,
                                                     eps=bn_inner.eps,
                                                     momentum=bn_inner.momentum,
                                                     has_bias=conv_inner.has_bias,
                                                     bias_init=conv_inner.bias_init,
                                                     freeze_bn=self.freeze_bn,
                                                     quant_config=self.quant_config,
                                                     quant_dtype=self.weight_dtype,
                                                     fake=True)
                # The fold variant carries the BatchNorm parameters directly on the conv cell.
                bn_target = conv_inner
            else:
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
                                                            conv_inner.out_channels,
                                                            kernel_size=conv_inner.kernel_size,
                                                            stride=conv_inner.stride,
                                                            pad_mode=conv_inner.pad_mode,
                                                            padding=conv_inner.padding,
                                                            dilation=conv_inner.dilation,
                                                            group=conv_inner.group,
                                                            eps=bn_inner.eps,
                                                            momentum=bn_inner.momentum,
                                                            has_bias=conv_inner.has_bias,
                                                            bias_init=conv_inner.bias_init,
                                                            quant_config=self.quant_config,
                                                            quant_dtype=self.weight_dtype)
                # The non-fold variant keeps them on its own `batchnorm` attribute.
                bn_target = conv_inner.batchnorm
            # change original network BatchNormal OP parameters to quant network
            bn_target.gamma = bn_inner.gamma
            bn_target.beta = bn_inner.beta
            bn_target.moving_mean = bn_inner.moving_mean
            bn_target.moving_variance = bn_inner.moving_variance
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
                                           conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size,
                                           stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode,
                                           padding=conv_inner.padding,
                                           dilation=conv_inner.dilation,
                                           group=conv_inner.group,
                                           has_bias=conv_inner.has_bias,
                                           quant_config=self.quant_config,
                                           quant_dtype=self.weight_dtype)
        # change original network Conv2D OP parameters to quant network
        conv_inner.weight = subcell.conv.weight
        if subcell.conv.has_bias:
            conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = self._add_fake_quant_after(F.identity)
        return subcell

    def _convert_dense(self, subcell):
        """
        convert dense cell to combine dense cell

        The inner dense layer of `subcell` is replaced by a `DenseQuant` that reuses the original weight
        (and bias, if any).
        """
        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # change original network Dense OP parameters to quant network
        dense_inner.weight = subcell.dense.weight
        if subcell.dense.has_bias:
            dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = self._add_fake_quant_after(F.identity)
        return subcell

    def _convert_activation(self, activation):
        """
        Build the quant cell registered in `_ACTIVATION_MAP` for the class of `activation`.

        Raises:
            ValueError: If the class of `activation` is not a key of `_ACTIVATION_MAP`.
        """
        act_class = activation.__class__
        if act_class not in _ACTIVATION_MAP:
            raise ValueError("Unsupported activation in auto quant: ", act_class)
        return _ACTIVATION_MAP[act_class](activation=activation,
                                          quant_config=self.quant_config,
                                          quant_dtype=self.act_dtype)