# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """quantization aware.""" import copy import re import numpy as np import mindspore.context as context from ... import log as logger from ... import nn, ops from ..._checkparam import ParamValidator as validator from ..._checkparam import Rel from ...common import Tensor from ...common import dtype as mstype from ...common.api import _executor from ...nn.layer import quant from ...ops import functional as F from ...ops import operations as P from ...ops.operations import _inner_ops as inner from ...train import serialization from . import quant_utils _ACTIVATION_MAP = {nn.ReLU: quant.ActQuant, nn.ReLU6: quant.ActQuant, nn.Sigmoid: quant.ActQuant, nn.LeakyReLU: quant.LeakyReLUQuant, nn.HSigmoid: quant.HSigmoidQuant, nn.HSwish: quant.HSwishQuant} class _AddFakeQuantInput(nn.Cell): """ Add FakeQuant OP at input of the network. Only support one input case. """ def __init__(self, network, quant_delay=0): super(_AddFakeQuantInput, self).__init__(auto_prefix=False) self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True) self.fake_quant_input.update_parameters_name('fake_quant_input.') self.network = network def construct(self, data): data = self.fake_quant_input(data) output = self.network(data) return output class _AddFakeQuantAfterSubCell(nn.Cell): """ Add FakeQuant OP after of the sub Cell. """ def __init__(self, subcell, **kwargs): super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) self.subcell = subcell self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, ema=True, num_bits=kwargs["num_bits"], quant_delay=kwargs["quant_delay"], per_channel=kwargs["per_channel"], symmetric=kwargs["symmetric"], narrow_range=kwargs["narrow_range"]) def construct(self, *data): output = self.subcell(*data) output = self.fake_quant_act(output) return output class ConvertToQuantNetwork: """ Convert network to quantization aware network """ __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] def __init__(self, **kwargs): self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE) self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE) self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"]) self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE) self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE) self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE) self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0]) self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1]) self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0]) self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1]) self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0]) self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1]) self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, quant.DenseBnAct: self._convert_dense} def _convert_op_name(self, name): pattern = re.compile(r'([A-Z]{1})') name_new = re.sub(pattern, r'_\1', name).lower() if name_new[0] == '_': name_new = name_new[1:] return name_new def run(self): self.network.update_cell_prefix() network = self._convert_subcells2quant(self.network) self.network.update_cell_type("quant") return network def _convert_subcells2quant(self, network): """ convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell """ cells = network.name_cells() change = False for name in cells: subcell = cells[name] if subcell == network: continue elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)): prefix = subcell.param_prefix new_subcell = self._convert_method_map[type(subcell)](subcell) new_subcell.update_parameters_name(prefix + '.') network.insert_child_to_cell(name, new_subcell) change = True else: self._convert_subcells2quant(subcell) if isinstance(network, nn.SequentialCell) and change: network.cell_list = list(network.cells()) # add FakeQuant OP after OP in while list add_list = [] for name in network.__dict__: if name[0] == '_': continue attr = network.__dict__[name] if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__: add_list.append((name, attr)) for name, prim_op in add_list: prefix = name add_quant = _AddFakeQuantAfterSubCell(prim_op, num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, narrow_range=self.act_range) prefix = self._convert_op_name(prim_op.name) if network.param_prefix: prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) add_quant.update_parameters_name(prefix + '.') del network.__dict__[name] network.insert_child_to_cell(name, add_quant) return network def _convert_conv(self, subcell): """ convert Conv2d cell to quant cell """ conv_inner = subcell.conv if subcell.has_bn: if self.bn_fold: bn_inner = subcell.batchnorm conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels, conv_inner.out_channels, kernel_size=conv_inner.kernel_size, stride=conv_inner.stride, pad_mode=conv_inner.pad_mode, padding=conv_inner.padding, dilation=conv_inner.dilation, group=conv_inner.group, eps=bn_inner.eps, momentum=bn_inner.momentum, quant_delay=self.weight_qdelay, freeze_bn=self.freeze_bn, per_channel=self.weight_channel, num_bits=self.weight_bits, fake=True, symmetric=self.weight_symmetric, narrow_range=self.weight_range, has_bias=conv_inner.has_bias, bias_init=conv_inner.bias_init) # change original network BatchNormal OP parameters to quant network conv_inner.gamma = subcell.batchnorm.gamma conv_inner.beta = subcell.batchnorm.beta conv_inner.moving_mean = subcell.batchnorm.moving_mean conv_inner.moving_variance = subcell.batchnorm.moving_variance del subcell.batchnorm subcell.batchnorm = None subcell.has_bn = False else: bn_inner = subcell.batchnorm conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels, conv_inner.out_channels, kernel_size=conv_inner.kernel_size, stride=conv_inner.stride, pad_mode=conv_inner.pad_mode, padding=conv_inner.padding, dilation=conv_inner.dilation, group=conv_inner.group, eps=bn_inner.eps, momentum=bn_inner.momentum, quant_delay=self.weight_qdelay, per_channel=self.weight_channel, num_bits=self.weight_bits, symmetric=self.weight_symmetric, narrow_range=self.weight_range, has_bias=conv_inner.has_bias, bias_init=conv_inner.bias_init) # change original network BatchNormal OP parameters to quant network conv_inner.batchnorm.gamma = subcell.batchnorm.gamma conv_inner.batchnorm.beta = subcell.batchnorm.beta conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance del subcell.batchnorm subcell.batchnorm = None subcell.has_bn = False else: conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels, kernel_size=conv_inner.kernel_size, stride=conv_inner.stride, pad_mode=conv_inner.pad_mode, padding=conv_inner.padding, dilation=conv_inner.dilation, group=conv_inner.group, has_bias=conv_inner.has_bias, quant_delay=self.weight_qdelay, per_channel=self.weight_channel, num_bits=self.weight_bits, symmetric=self.weight_symmetric, narrow_range=self.weight_range) # change original network Conv2D OP parameters to quant network conv_inner.weight = subcell.conv.weight if subcell.conv.has_bias: conv_inner.bias = subcell.conv.bias subcell.conv = conv_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) elif subcell.after_fake: subcell.has_act = True subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, narrow_range=self.act_range) return subcell def _convert_dense(self, subcell): """ convert dense cell to combine dense cell """ dense_inner = subcell.dense dense_inner = quant.DenseQuant(dense_inner.in_channels, dense_inner.out_channels, has_bias=dense_inner.has_bias, num_bits=self.weight_bits, quant_delay=self.weight_qdelay, per_channel=self.weight_channel, symmetric=self.weight_symmetric, narrow_range=self.weight_range) # change original network Dense OP parameters to quant network dense_inner.weight = subcell.dense.weight if subcell.dense.has_bias: dense_inner.bias = subcell.dense.bias subcell.dense = dense_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) elif subcell.after_fake: subcell.has_act = True subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, narrow_range=self.act_range) return subcell def _convert_activation(self, activation): act_class = activation.__class__ if act_class not in _ACTIVATION_MAP: raise ValueError("Unsupported activation in auto quant: ", act_class) return _ACTIVATION_MAP[act_class](activation=activation, num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, narrow_range=self.act_range) class ExportToQuantInferNetwork: """ Convert quantization aware network to infer network. Args: network (Cell): MindSpore network API `convert_quant_network`. inputs (Tensor): Input tensors of the `quantization aware training network`. mean (int): Input data mean. Default: 127.5. std_dev (int, float): Input data variance. Default: 127.5. is_mindir (bool): Whether is MINDIR format. Default: False. Returns: Cell, Infer network. """ __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): network = validator.check_isinstance('network', network, (nn.Cell,)) # quantize for inputs: q = f / scale + zero_point # dequantize for outputs: f = (q - zero_point) * scale self.input_scale = 1 / std_dev self.input_zero_point = round(mean) self.data_type = mstype.int8 self.network = copy.deepcopy(network) self.all_parameters = {p.name: p for p in self.network.get_parameters()} self.get_inputs_table(inputs) self.mean = mean self.std_dev = std_dev self.is_mindir = is_mindir def get_inputs_table(self, inputs): """Get the support info for quant export.""" phase_name = 'export_quant' graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False) self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id) def run(self): """Start to convert.""" self.network.update_cell_prefix() network = self.network if isinstance(network, _AddFakeQuantInput): network = network.network network = self._convert_quant2deploy(network) return network def _get_quant_block(self, cell_core, activation, fake_quant_a_out): """convet network's quant subcell to deploy subcell""" # Calculate the scale and zero point w_minq_name = cell_core.fake_quant_weight.minq.name np_type = mstype.dtype_to_nptype(self.data_type) param_dict = dict() param_dict["filter_maxq"] = None param_dict["filter_minq"] = None param_dict["output_maxq"] = None param_dict["output_minq"] = None param_dict["input_maxq"] = None param_dict["input_minq"] = None param_dict["mean"] = self.mean param_dict["std_dev"] = self.std_dev param_dict["symmetric"] = fake_quant_a_out.symmetric if self.is_mindir: scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \ quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type) scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \ quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type) else: scale_w, zp_w = quant_utils.scale_zp_from_fake_quant_cell(cell_core.fake_quant_weight, np_type) scale_a_out, _ = quant_utils.scale_zp_from_fake_quant_cell(fake_quant_a_out, np_type) info = self.quant_info_table.get(w_minq_name, None) if info: fack_quant_a_in_op, minq_name = info if minq_name == 'input': scale_a_in, zp_a_in = self.input_scale, self.input_zero_point else: maxq = self.all_parameters[minq_name[:-4] + "maxq"] minq = self.all_parameters[minq_name] if self.is_mindir: scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ quant_utils.scale_zp_max_min_from_data(fack_quant_a_in_op, minq, maxq, np_type) else: scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type) else: logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}") return None # Build the `Quant` `Dequant` op. # Quant only support perlayer version. Need check here. quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in)) scale_deq = scale_a_out * scale_w dequant_op = inner.Dequant() if isinstance(activation, _AddFakeQuantAfterSubCell): activation = activation.subcell elif hasattr(activation, "get_origin"): activation = activation.get_origin() # get the `weight` and `bias` weight = cell_core.weight.data.asnumpy() bias = None if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)): if cell_core.has_bias: bias = cell_core.bias.data.asnumpy() elif isinstance(cell_core, quant.Conv2dBnFoldQuant): weight, bias = quant_utils.fold_batchnorm(weight, cell_core) elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant): weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core) weight_b = weight bias_b = bias # apply the quant weight = quant_utils.weight2int(weight, scale_w, zp_w) if bias is not None: bias = Tensor(bias / scale_a_in / scale_w, mstype.int32) # fuse parameter # |--------|47:40|--------|39:32|--------|31:0| # offset_w [8] shift_N [8] deq_scale [32] float32_deq_scale = scale_deq.astype(np.float32) uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32) scale_length = scale_deq.size # channel dequant_param = np.zeros(scale_length, dtype=np.uint64) for index in range(scale_length): dequant_param[index] += uint32_deq_scale[index] scale_deq = Tensor(dequant_param, mstype.uint64) # get op if isinstance(cell_core, quant.DenseQuant): op_core = P.MatMul() weight = np.transpose(weight) weight_b = np.transpose(weight_b) else: op_core = cell_core.conv weight = Tensor(weight, self.data_type) weight_b = Tensor(weight_b) if bias_b is not None: bias_b = Tensor(bias_b, mstype.float32) if self.is_mindir: block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict) else: block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) return block def _convert_quant2deploy(self, network): """Convert network's all quant subcell to deploy subcell.""" cells = network.name_cells() change = False for name in cells: subcell = cells[name] if subcell == network: continue cell_core = None fake_quant_act = None activation = None if isinstance(subcell, quant.Conv2dBnAct): cell_core = subcell.conv activation = subcell.activation fake_quant_act = activation.fake_quant_act elif isinstance(subcell, quant.DenseBnAct): cell_core = subcell.dense activation = subcell.activation fake_quant_act = activation.fake_quant_act if cell_core is not None: new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) if new_subcell: prefix = subcell.param_prefix new_subcell.update_parameters_name(prefix + '.') network.insert_child_to_cell(name, new_subcell) change = True elif isinstance(subcell, _AddFakeQuantAfterSubCell): op = subcell.subcell if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive): network.__delattr__(name) network.__setattr__(name, op) change = True else: self._convert_quant2deploy(subcell) if isinstance(network, nn.SequentialCell) and change: network.cell_list = list(network.cells()) return network def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='AIR'): """ Exports MindSpore quantization predict model to deploy with AIR. Args: network (Cell): MindSpore network produced by `convert_quant_network`. inputs (Tensor): Inputs of the `quantization aware training network`. file_name (str): File name of model to export. mean (int, float): Input data mean. Default: 127.5. std_dev (int, float): Input data variance. Default: 127.5. file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported quantization aware model. Default: 'AIR'. - AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of Ascend model. - MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format for MindSpore models. Recommended suffix for output file is '.mindir'. """ supported_device = ["Ascend", "GPU"] supported_formats = ['AIR', 'MINDIR'] mean = validator.check_type("mean", mean, (int, float)) std_dev = validator.check_type("std_dev", std_dev, (int, float)) if context.get_context('device_target') not in supported_device: raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) if file_format not in supported_formats: raise ValueError('Illegal file format {}.'.format(file_format)) network.set_train(False) if file_format == "MINDIR": exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True) else: exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs) deploy_net = exporter.run() serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format) def convert_quant_network(network, bn_fold=True, freeze_bn=10000000, quant_delay=(0, 0), num_bits=(8, 8), per_channel=(False, False), symmetric=(False, False), narrow_range=(False, False) ): r""" Create quantization aware training network. Args: network (Cell): Obtain a pipeline through network for saving graph summary. bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True. freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7. quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during eval. The first element represent weights and second element represent data flow. Default: (0, 0) num_bits (int, list or tuple): Number of bits to use for quantize weights and activations. The first element represent weights and second element represent data flow. Default: (8, 8) per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` then base on per channel otherwise base on per layer. The first element represent weights and second element represent data flow. Default: (False, False) symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on symmetric otherwise base on asymmetric. The first element represent weights and second element represent data flow. Default: (False, False) narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not. The first element represents weights and the second element represents data flow. Default: (False, False) Returns: Cell, Network which has change to quantization aware training network cell. """ support_device = ["Ascend", "GPU"] def convert2list(name, value): if not isinstance(value, list) and not isinstance(value, tuple): value = [value] elif len(value) > 2: raise ValueError("input `{}` len should less then 2".format(name)) return value quant_delay = convert2list("quant delay", quant_delay) num_bits = convert2list("num bits", num_bits) per_channel = convert2list("per channel", per_channel) symmetric = convert2list("symmetric", symmetric) narrow_range = convert2list("narrow range", narrow_range) if context.get_context('device_target') not in support_device: raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) net = ConvertToQuantNetwork(network=network, quant_delay=quant_delay, bn_fold=bn_fold, freeze_bn=freeze_bn, num_bits=num_bits, per_channel=per_channel, symmetric=symmetric, narrow_range=narrow_range) return net.run()