Merge pull request !7463 from yuchaojie/quant2tags/v1.1.0
| @@ -0,0 +1,17 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Compression export module. | |||||
| """ | |||||
| @@ -12,323 +12,28 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """quantization aware.""" | |||||
| """Export for quantization.""" | |||||
| import copy | import copy | ||||
| import re | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from ... import log as logger | from ... import log as logger | ||||
| from ... import nn, ops | from ... import nn, ops | ||||
| from ..._checkparam import Validator, Rel | |||||
| from ..._checkparam import Validator | |||||
| from ...common import Tensor | from ...common import Tensor | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.api import _executor | from ...common.api import _executor | ||||
| from ...nn.layer import quant | from ...nn.layer import quant | ||||
| from ...compression.common import QuantDtype | |||||
| from ...ops import functional as F | |||||
| from ...ops import operations as P | from ...ops import operations as P | ||||
| from ...ops.operations import _inner_ops as inner | from ...ops.operations import _inner_ops as inner | ||||
| from ...train import serialization | from ...train import serialization | ||||
| from . import quant_utils | |||||
# Lookup table from a float activation cell class (key) to the quantization
# layer class (value) that replaces it. `_convert_activation` rejects any
# activation whose class is not a key of this table.
_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
                   nn.ReLU6: quant.ActQuant,
                   nn.Sigmoid: quant.ActQuant,
                   nn.LeakyReLU: quant.LeakyReLUQuant,
                   nn.HSigmoid: quant.HSigmoidQuant,
                   nn.HSwish: quant.HSwishQuant}
| def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver), | |||||
| quant_delay=(0, 0), | |||||
| quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), | |||||
| per_channel=(False, False), | |||||
| symmetric=(False, False), | |||||
| narrow_range=(False, False) | |||||
| ): | |||||
| r""" | |||||
| Configs the oberser type of weights and data flow with quant params. | |||||
| from ..quant import quant_utils | |||||
| from ..quant.qat import QuantizationAwareTraining, _AddFakeQuantInput, _AddFakeQuantAfterSubCell | |||||
| Args: | |||||
| quant_observer (Observer, list or tuple): The oberser type to do quantization. The first element represent | |||||
| weights and second element represent data flow. | |||||
| Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver) | |||||
| quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during | |||||
| eval. The first element represent weights and second element represent data flow. Default: (0, 0) | |||||
| quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first | |||||
| element represent weights and second element represent data flow. | |||||
| Default: (QuantDtype.INT8, QuantDtype.INT8) | |||||
| per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` | |||||
| then base on per channel otherwise base on per layer. The first element represent weights | |||||
| and second element represent data flow. Default: (False, False) | |||||
| symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on | |||||
| symmetric otherwise base on asymmetric. The first element represent weights and second | |||||
| element represent data flow. Default: (False, False) | |||||
| narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not. | |||||
| The first element represents weights and the second element represents data flow. Default: (False, False) | |||||
| Returns: | |||||
| QuantConfig, Contains the oberser type of weight and activation. | |||||
| """ | |||||
| weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0], | |||||
| per_channel=per_channel[0], symmetric=symmetric[0], | |||||
| narrow_range=narrow_range[0]) | |||||
| act_observer = quant_observer[0].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1], | |||||
| per_channel=per_channel[-1], symmetric=symmetric[-1], | |||||
| narrow_range=narrow_range[-1]) | |||||
| return quant.QuantConfig(weight=weight_observer, activation=act_observer) | |||||
class _AddFakeQuantInput(nn.Cell):
    """
    Wrap a network so that its single input passes through a FakeQuant observer first.

    Args:
        network (Cell): The network that receives the fake-quantized input.
        quant_delay (int): Forwarded to the input observer as `quant_delay`. Default: 0.
    """

    def __init__(self, network, quant_delay=0):
        super().__init__(auto_prefix=False)
        # The observer is registered before the wrapped network, as in the original ordering.
        self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(
            min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
        self.fake_quant_input.update_parameters_name('fake_quant_input.')
        self.network = network

    def construct(self, data):
        """Fake-quantize `data`, then run the wrapped network on the result."""
        return self.network(self.fake_quant_input(data))
class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Run a sub cell (or primitive) and pass its output through a FakeQuant observer.

    Args:
        subcell: The callable whose output is fake-quantized.
        kwargs: Must contain `quant_dtype`, `quant_delay`, `per_channel`, `symmetric`
            and `narrow_range`; they are forwarded to the observer.
    """

    def __init__(self, subcell, **kwargs):
        super().__init__(auto_prefix=False)
        self.subcell = subcell
        # Keys are read in the same order as the original keyword list so a missing
        # key fails on the same name.
        observer_options = {key: kwargs[key] for key in
                            ("quant_dtype", "quant_delay", "per_channel", "symmetric", "narrow_range")}
        self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
                                                                max_init=6,
                                                                ema=True,
                                                                **observer_options)

    def construct(self, *data):
        """Apply the wrapped sub cell, then fake-quantize what it returns."""
        return self.fake_quant_act(self.subcell(*data))
class ConvertToQuantNetwork:
    """
    Convert network to quantization aware network.

    `Conv2dBnAct` and `DenseBnAct` sub cells are rewritten in place so that their inner
    convolution / dense layer (and activation) become quantization-aware layers, and every
    primitive attribute whose name is listed in `__quant_op_name__` is wrapped so that a
    FakeQuant observer runs after it.

    Args:
        kwargs: Must contain `network`, `quant_delay`, `bn_fold`, `freeze_bn`, `quant_dtype`,
            `per_channel`, `symmetric` and `narrow_range`. For the sequence-valued options,
            element 0 configures weights and element -1 configures activations.
    """
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
    # Compiled once instead of on every `_convert_op_name` call.
    _UPPERCASE_PATTERN = re.compile(r'([A-Z]{1})')

    def __init__(self, **kwargs):
        self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
        self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
        self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
        self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
        self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
        self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
        self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
        self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
        self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
        # Dispatch table keyed by the exact sub cell type.
        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                    quant.DenseBnAct: self._convert_dense}
        self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
                                             quant_dtype=kwargs["quant_dtype"],
                                             per_channel=kwargs["per_channel"],
                                             symmetric=kwargs["symmetric"],
                                             narrow_range=kwargs["narrow_range"])

    def _convert_op_name(self, name):
        """
        Turn a CamelCase primitive name into snake_case, e.g. "TensorAdd" -> "tensor_add".

        An empty name is returned unchanged instead of raising IndexError.
        """
        name_new = self._UPPERCASE_PATTERN.sub(r'_\1', name).lower()
        # Only the single underscore produced by a leading capital is dropped.
        if name_new.startswith('_'):
            name_new = name_new[1:]
        return name_new

    def run(self):
        """Convert `self.network` in place and return the converted network."""
        self.network.update_cell_prefix()
        network = self._convert_subcells2quant(self.network)
        self.network.update_cell_type("quant")
        return network

    def _new_fake_quant_after(self, op):
        """Wrap `op` so that a FakeQuant observer with the activation settings runs after it."""
        return _AddFakeQuantAfterSubCell(op,
                                         quant_dtype=self.act_dtype,
                                         quant_delay=self.act_qdelay,
                                         per_channel=self.act_channel,
                                         symmetric=self.act_symmetric,
                                         narrow_range=self.act_range)

    def _convert_subcells2quant(self, network):
        """
        convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                # Any other cell may contain convertible cells deeper down.
                self._convert_subcells2quant(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # add FakeQuant OP after OP in white list
        # Candidates are collected first because `network.__dict__` is mutated below.
        add_list = []
        for name in network.__dict__:
            if name[0] == '_':
                continue
            attr = network.__dict__[name]
            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            add_quant = self._new_fake_quant_after(prim_op)
            prefix = self._convert_op_name(prim_op.name)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, prefix])
            add_quant.update_parameters_name(prefix + '.')
            # Replace the plain attribute with the wrapping child cell of the same name.
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_act_or_add_fake(self, subcell):
        """
        Quantize the activation of `subcell`, or append a FakeQuant identity when the
        cell has no activation but requests one via `after_fake`.
        """
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = self._new_fake_quant_after(F.identity)

    def _convert_conv(self, subcell):
        """
        convert Conv2d cell to quant cell
        """
        conv_inner = subcell.conv
        if subcell.has_bn:
            bn_inner = subcell.batchnorm
            if self.bn_fold:
                conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
                                                     conv_inner.out_channels,
                                                     kernel_size=conv_inner.kernel_size,
                                                     stride=conv_inner.stride,
                                                     pad_mode=conv_inner.pad_mode,
                                                     padding=conv_inner.padding,
                                                     dilation=conv_inner.dilation,
                                                     group=conv_inner.group,
                                                     eps=bn_inner.eps,
                                                     momentum=bn_inner.momentum,
                                                     has_bias=conv_inner.has_bias,
                                                     bias_init=conv_inner.bias_init,
                                                     freeze_bn=self.freeze_bn,
                                                     quant_config=self.quant_config,
                                                     quant_dtype=self.weight_dtype,
                                                     fake=True)
                # change original network BatchNormal OP parameters to quant network
                conv_inner.gamma = bn_inner.gamma
                conv_inner.beta = bn_inner.beta
                conv_inner.moving_mean = bn_inner.moving_mean
                conv_inner.moving_variance = bn_inner.moving_variance
            else:
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
                                                            conv_inner.out_channels,
                                                            kernel_size=conv_inner.kernel_size,
                                                            stride=conv_inner.stride,
                                                            pad_mode=conv_inner.pad_mode,
                                                            padding=conv_inner.padding,
                                                            dilation=conv_inner.dilation,
                                                            group=conv_inner.group,
                                                            eps=bn_inner.eps,
                                                            momentum=bn_inner.momentum,
                                                            has_bias=conv_inner.has_bias,
                                                            bias_init=conv_inner.bias_init,
                                                            quant_config=self.quant_config,
                                                            quant_dtype=self.weight_dtype)
                # change original network BatchNormal OP parameters to quant network
                conv_inner.batchnorm.gamma = bn_inner.gamma
                conv_inner.batchnorm.beta = bn_inner.beta
                conv_inner.batchnorm.moving_mean = bn_inner.moving_mean
                conv_inner.batchnorm.moving_variance = bn_inner.moving_variance
            # The batch norm now lives inside the quant conv cell, so detach it from the wrapper.
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
                                           conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size,
                                           stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode,
                                           padding=conv_inner.padding,
                                           dilation=conv_inner.dilation,
                                           group=conv_inner.group,
                                           has_bias=conv_inner.has_bias,
                                           quant_config=self.quant_config,
                                           quant_dtype=self.weight_dtype)
        # change original network Conv2D OP parameters to quant network
        conv_inner.weight = subcell.conv.weight
        if subcell.conv.has_bias:
            conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        self._convert_act_or_add_fake(subcell)
        return subcell

    def _convert_dense(self, subcell):
        """
        convert dense cell to combine dense cell
        """
        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # change original network Dense OP parameters to quant network
        dense_inner.weight = subcell.dense.weight
        if subcell.dense.has_bias:
            dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        self._convert_act_or_add_fake(subcell)
        return subcell

    def _convert_activation(self, activation):
        """
        Replace a float activation cell with its quantization-aware counterpart.

        Raises:
            ValueError: If the activation class has no entry in `_ACTIVATION_MAP`.
        """
        act_class = activation.__class__
        if act_class not in _ACTIVATION_MAP:
            # Format the class into the message; passing it as a second argument made
            # the exception text a tuple.
            raise ValueError("Unsupported activation in auto quant: {}".format(act_class))
        return _ACTIVATION_MAP[act_class](activation=activation,
                                          quant_config=self.quant_config,
                                          quant_dtype=self.act_dtype)
| __all__ = ["export", "manual_export"] | |||||
| class ExportToQuantInferNetwork: | class ExportToQuantInferNetwork: | ||||
| """ | """ | ||||
| @@ -499,7 +204,7 @@ class ExportToQuantInferNetwork: | |||||
| change = True | change = True | ||||
| elif isinstance(subcell, _AddFakeQuantAfterSubCell): | elif isinstance(subcell, _AddFakeQuantAfterSubCell): | ||||
| op = subcell.subcell | op = subcell.subcell | ||||
| if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive): | |||||
| if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive): | |||||
| if self.is_mindir: | if self.is_mindir: | ||||
| op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy())) | op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy())) | ||||
| op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy())) | op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy())) | ||||
| @@ -553,106 +258,6 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' | |||||
| serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format) | serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format) | ||||
def convert_quant_network(network,
                          bn_fold=True,
                          freeze_bn=10000000,
                          quant_delay=(0, 0),
                          quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                          per_channel=(False, False),
                          symmetric=(False, False),
                          narrow_range=(False, False)
                          ):
    r"""
    Create quantization aware training network.

    Args:
        network (Cell): Obtain a pipeline through network for saving graph summary.
        bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7.
        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
        quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
            element represent weights and second element represent data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (bool, list or tuple):  Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represent weights
            and second element represent data flow. Default: (False, False)
        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
            symmetric otherwise base on asymmetric. The first element represent weights and second
            element represent data flow. Default: (False, False)
        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the second element represents data flow. Default: (False, False)

    Returns:
        Cell, Network which has change to quantization aware training network cell.

    Raises:
        ValueError: If a list/tuple option is empty or has more than two elements.
        KeyError: If the current device target is neither "Ascend" nor "GPU".
    """
    support_device = ["Ascend", "GPU"]

    def convert2list(name, value):
        """Wrap a scalar option into a list; validate the length of list/tuple options."""
        if not isinstance(value, (list, tuple)):
            value = [value]
        elif not value or len(value) > 2:
            # Two elements are accepted, so the old "less then 2" wording was misleading.
            # Empty sequences are rejected here rather than failing later with IndexError.
            raise ValueError("input `{}` should contain 1 or 2 elements, but got {}".format(name, len(value)))
        return value

    quant_delay = convert2list("quant delay", quant_delay)
    quant_dtype = convert2list("quant dtype", quant_dtype)
    per_channel = convert2list("per channel", per_channel)
    symmetric = convert2list("symmetric", symmetric)
    narrow_range = convert2list("narrow range", narrow_range)

    device_target = context.get_context('device_target')
    if device_target not in support_device:
        raise KeyError("Unsupported {} device target.".format(device_target))

    net = ConvertToQuantNetwork(network=network,
                                quant_delay=quant_delay,
                                bn_fold=bn_fold,
                                freeze_bn=freeze_bn,
                                quant_dtype=quant_dtype,
                                per_channel=per_channel,
                                symmetric=symmetric,
                                narrow_range=narrow_range)
    return net.run()
def manual_export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='MINDIR'):
    """
    Manually export a MindSpore quantization predict model for deployment as AIR or MINDIR.

    Args:
        network (Cell): MindSpore network produced by `convert_quant_network`.
        inputs (Tensor): Inputs of the `quantization aware training network`.
        file_name (str): File name of model to export.
        mean (int, float): Input data mean. Default: 127.5.
        std_dev (int, float): Input data variance (wording of the original documentation; the
            parameter name suggests a standard deviation). Default: 127.5.
        file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
            quantization aware model. Default: 'MINDIR'.

            - AIR: Graph Engine Intermediate Representation. An intermediate representation format of
              Ascend model.
            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
              for MindSpore models.
              Recommended suffix for output file is '.mindir'.

    Raises:
        KeyError: If the current device target is neither "Ascend" nor "GPU".
        ValueError: If `file_format` is neither 'AIR' nor 'MINDIR'.
    """
    supported_device = ["Ascend", "GPU"]
    supported_formats = ['AIR', 'MINDIR']

    mean = Validator.check_type("mean", mean, (int, float))
    std_dev = Validator.check_type("std_dev", std_dev, (int, float))

    device_target = context.get_context('device_target')
    if device_target not in supported_device:
        raise KeyError("Unsupported {} device target.".format(device_target))

    if file_format not in supported_formats:
        raise ValueError('Illegal file format {}.'.format(file_format))

    network.set_train(False)
    # The exporter differs only in the `is_mindir` flag, so derive it from the format
    # instead of duplicating the constructor call.
    exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs,
                                        is_mindir=(file_format == "MINDIR"))
    deploy_net = exporter.run()
    serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
| class ExportManualQuantNetwork: | class ExportManualQuantNetwork: | ||||
| """ | """ | ||||
| Convert manual quantization aware network to infer network. | Convert manual quantization aware network to infer network. | ||||
| @@ -713,7 +318,7 @@ class ExportManualQuantNetwork: | |||||
| elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant, | elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant, | ||||
| quant.Conv2dQuant, quant.DenseQuant)): | quant.Conv2dQuant, quant.DenseQuant)): | ||||
| network, change = self._convert_subcell(network, change, name, subcell, core=False) | network, change = self._convert_subcell(network, change, name, subcell, core=False) | ||||
| elif isinstance(subcell, quant.FakeQuantWithMinMax) and self.upcell: | |||||
| elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver) and self.upcell: | |||||
| np_type = mstype.dtype_to_nptype(self.data_type) | np_type = mstype.dtype_to_nptype(self.data_type) | ||||
| _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type) | _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type) | ||||
| self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq)) | self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq)) | ||||
| @@ -721,7 +326,7 @@ class ExportManualQuantNetwork: | |||||
| network.insert_child_to_cell(self.upname, self.upcell) | network.insert_child_to_cell(self.upname, self.upcell) | ||||
| elif isinstance(subcell, _AddFakeQuantAfterSubCell): | elif isinstance(subcell, _AddFakeQuantAfterSubCell): | ||||
| op = subcell.subcell | op = subcell.subcell | ||||
| if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive): | |||||
| if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive): | |||||
| if self.is_mindir: | if self.is_mindir: | ||||
| op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy())) | op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy())) | ||||
| op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy())) | op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy())) | ||||
| @@ -845,3 +450,43 @@ class ExportManualQuantNetwork: | |||||
| else: | else: | ||||
| block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) | block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) | ||||
| return block | return block | ||||
def manual_export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='MINDIR'):
    """
    Manually export a MindSpore quantization predict model for deployment as AIR or MINDIR.

    Args:
        network (Cell): MindSpore network produced by `convert_quant_network`.
        inputs (Tensor): Inputs of the `quantization aware training network`.
        file_name (str): File name of model to export.
        mean (int, float): Input data mean. Default: 127.5.
        std_dev (int, float): Input data variance (wording of the original documentation; the
            parameter name suggests a standard deviation). Default: 127.5.
        file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
            quantization aware model. Default: 'MINDIR'.

            - AIR: Graph Engine Intermediate Representation. An intermediate representation format of
              Ascend model.
            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
              for MindSpore models.
              Recommended suffix for output file is '.mindir'.

    Raises:
        KeyError: If the current device target is neither "Ascend" nor "GPU".
        ValueError: If `file_format` is neither 'AIR' nor 'MINDIR'.
    """
    supported_device = ["Ascend", "GPU"]
    supported_formats = ['AIR', 'MINDIR']

    mean = Validator.check_type("mean", mean, (int, float))
    std_dev = Validator.check_type("std_dev", std_dev, (int, float))

    device_target = context.get_context('device_target')
    if device_target not in supported_device:
        raise KeyError("Unsupported {} device target.".format(device_target))

    if file_format not in supported_formats:
        raise ValueError('Illegal file format {}.'.format(file_format))

    network.set_train(False)
    # The exporter differs only in the `is_mindir` flag, so derive it from the format
    # instead of duplicating the constructor call.
    exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs,
                                        is_mindir=(file_format == "MINDIR"))
    deploy_net = exporter.run()
    serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
| @@ -13,14 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """ | """ | ||||
| Quantization. | |||||
| User can use quantization aware to train a model. MindSpore supports quantization aware training, | |||||
| which models quantization errors in both the forward and backward passes using fake-quantization | |||||
| operations. Note that the entire computation is carried out in floating point. At the end of quantization | |||||
| aware training, MindSpore provides conversion functions to convert the trained model into lower precision. | |||||
| Compression quant module. | |||||
| """ | """ | ||||
| from .quant import convert_quant_network, export, manual_export | |||||
| __all__ = ["convert_quant_network", "export", "manual_export"] | |||||
| from .quantizer import * | |||||
| from .qat import * | |||||
| from .quant_utils import * | |||||
| @@ -0,0 +1,406 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Quantization aware training | |||||
| User can use quantization aware to train a model. MindSpore supports quantization aware training, | |||||
| which models quantization errors in both the forward and backward passes using fake-quantization | |||||
| operations. Note that the entire computation is carried out in floating point. At the end of quantization | |||||
| aware training, MindSpore provides conversion functions to convert the trained model into lower precision. | |||||
| """ | |||||
| import re | |||||
| import mindspore.context as context | |||||
| from ... import nn, ops | |||||
| from ..._checkparam import Validator, Rel | |||||
| from ...nn.layer import quant | |||||
| from ...ops import functional as F | |||||
| from ..common import QuantDtype | |||||
| from .quantizer import Quantizer, OptimizeOption | |||||
| __all__ = ["QuantizationAwareTraining"] | |||||
# Lookup table from a float activation cell class (key) to a quantization
# layer class (value). Presumably consulted when activations are converted
# to their quantization-aware form - confirm against the converter code.
_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
                   nn.ReLU6: quant.ActQuant,
                   nn.Sigmoid: quant.ActQuant,
                   nn.LeakyReLU: quant.LeakyReLUQuant,
                   nn.HSigmoid: quant.HSigmoidQuant,
                   nn.HSwish: quant.HSwishQuant}
def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
                     quant_delay=(0, 0),
                     quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                     per_channel=(False, False),
                     symmetric=(False, False),
                     narrow_range=(False, False)
                     ):
    r"""
    Configs the observer type of weights and data flow with quant params.

    For every sequence argument, element 0 configures the weight observer and element -1
    configures the activation (data flow) observer, so a one-element sequence applies the
    same value to both.

    Args:
        quant_observer (Observer, list or tuple): The observer type to do quantization. The first element represent
            weights and second element represent data flow.
            Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
        quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
            element represent weights and second element represent data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (bool, list or tuple):  Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represent weights
            and second element represent data flow. Default: (False, False)
        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
            symmetric otherwise base on asymmetric. The first element represent weights and second
            element represent data flow. Default: (False, False)
        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the second element represents data flow. Default: (False, False)

    Returns:
        QuantConfig, Contains the observer type of weight and activation.
    """
    weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
                                                     per_channel=per_channel[0], symmetric=symmetric[0],
                                                     narrow_range=narrow_range[0])
    # The activation observer must come from the last element, like every other option;
    # indexing [0] here silently ignored a distinct activation observer.
    act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
                                                   per_channel=per_channel[-1], symmetric=symmetric[-1],
                                                   narrow_range=narrow_range[-1])
    return quant.QuantConfig(weight=weight_observer, activation=act_observer)
class _AddFakeQuantInput(nn.Cell):
    """
    Wrap a network so that its single input passes through a FakeQuant OP first.

    Only the one-input case is supported.
    """

    def __init__(self, network, quant_delay=0):
        super().__init__(auto_prefix=False)
        input_observer = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
                                                           quant_delay=quant_delay, ema=True)
        self.fake_quant_input = input_observer
        self.fake_quant_input.update_parameters_name('fake_quant_input.')
        self.network = network

    def construct(self, data):
        # Fake-quantize the raw input, then run the wrapped network on the result.
        quantized = self.fake_quant_input(data)
        return self.network(quantized)
class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Run a wrapped callable and pass its output through a FakeQuant OP.

    The keyword arguments `quant_dtype`, `quant_delay`, `per_channel`, `symmetric`
    and `narrow_range` are all required; they are forwarded to the observer.
    """

    def __init__(self, subcell, **kwargs):
        super().__init__(auto_prefix=False)
        self.subcell = subcell
        self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(
            min_init=-6,
            max_init=6,
            ema=True,
            quant_dtype=kwargs["quant_dtype"],
            quant_delay=kwargs["quant_delay"],
            per_channel=kwargs["per_channel"],
            symmetric=kwargs["symmetric"],
            narrow_range=kwargs["narrow_range"])

    def construct(self, *data):
        # Forward every positional input to the wrapped callable, then fake-quantize its output.
        result = self.subcell(*data)
        return self.fake_quant_act(result)
class ConvertToQuantNetwork:
    """
    Convert network to quantization aware network.

    All settings are passed as keyword arguments: `network`, `quant_delay`, `bn_fold`,
    `freeze_bn`, `quant_dtype`, `per_channel`, `symmetric` and `narrow_range`. The
    pair-valued settings are indexed with ``[0]`` for weights and ``[-1]`` for activations.

    NOTE(review): `__init__` binds `self._convert_conv` and `self._convert_dense`, which are
    not defined in the visible part of this class -- confirm they exist on it.
    """
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self, **kwargs):
        check = Validator
        self.network = check.check_isinstance('network', kwargs["network"], (nn.Cell,))

        delays = kwargs["quant_delay"]
        self.weight_qdelay = check.check_non_negative_int(delays[0], "quant delay")
        self.act_qdelay = check.check_int(delays[-1], 0, Rel.GE, "quant delay")

        self.bn_fold = check.check_bool(kwargs["bn_fold"], "bn fold")
        self.freeze_bn = check.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")

        dtypes = kwargs["quant_dtype"]
        self.weight_dtype = check.check_isinstance("weights dtype", dtypes[0], QuantDtype)
        self.act_dtype = check.check_isinstance("activations dtype", dtypes[-1], QuantDtype)

        channels = kwargs["per_channel"]
        self.weight_channel = check.check_bool(channels[0], "per channel")
        self.act_channel = check.check_bool(channels[-1], "per channel")

        symmetric = kwargs["symmetric"]
        self.weight_symmetric = check.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = check.check_bool(symmetric[-1], "symmetric")

        ranges = kwargs["narrow_range"]
        self.weight_range = check.check_bool(ranges[0], "narrow range")
        self.act_range = check.check_bool(ranges[-1], "narrow range")

        # Dispatch table from fused cell class to the method that converts it.
        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                    quant.DenseBnAct: self._convert_dense}
        self.quant_config = get_quant_config(quant_delay=delays,
                                             quant_dtype=dtypes,
                                             per_channel=channels,
                                             symmetric=symmetric,
                                             narrow_range=ranges)
class QuantizationAwareTraining(Quantizer):
    r"""
    Quantizer for quantization aware training.

    Args:
        bn_fold (bool): Flag to use bn fold ops for simulation inference operation. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters use total mean and variance.
            Default: 1e7.
        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
            eval. The first element represents weights and the last element represents data flow. Default: (0, 0)
        quant_dtype (QuantDtype, list or tuple): Datatype used to quantize weights and activations. The first
            element represents weights and the last element represents data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represents weights
            and the last element represents data flow. Default: (False, False)
        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then
            base on symmetric otherwise base on asymmetric. The first element represents weights and the last
            element represents data flow. Default: (False, False)
        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the last element represents data flow.
            Default: (False, False)
        optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only
            support QAT. Default: OptimizeOption.QAT

    Raises:
        ValueError: If a list/tuple argument is empty or has more than two elements.
    """
    # Names of primitive operators that get a FakeQuant OP appended to their output.
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self,
                 bn_fold=True,
                 freeze_bn=10000000,
                 quant_delay=(0, 0),
                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                 per_channel=(False, False),
                 symmetric=(False, False),
                 narrow_range=(False, False),
                 optimize_option=OptimizeOption.QAT):
        """Init for QuantizationAwareTraining quantizer"""
        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)

        def convert2list(name, value):
            """Wrap a scalar into a one-element list; accept sequences of length 1 or 2 unchanged."""
            if not isinstance(value, (list, tuple)):
                return [value]
            # An empty sequence would otherwise surface later as a bare IndexError on value[0].
            if not value or len(value) > 2:
                raise ValueError("input `{}` len should be 1 or 2, but got {}".format(name, len(value)))
            return value

        quant_delay = convert2list("quant delay", quant_delay)
        quant_dtype = convert2list("quant dtype", quant_dtype)
        per_channel = convert2list("per channel", per_channel)
        symmetric = convert2list("symmetric", symmetric)
        narrow_range = convert2list("narrow range", narrow_range)

        # Index 0 configures weights; index -1 configures activations (data flow).
        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
        # Dispatch table from fused cell class to the method that converts it.
        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                    quant.DenseBnAct: self._convert_dense}
        self.quant_config = get_quant_config(quant_delay=quant_delay,
                                             quant_dtype=quant_dtype,
                                             per_channel=per_channel,
                                             symmetric=symmetric,
                                             narrow_range=narrow_range)

    def _convert_op_name(self, name):
        """Convert a CamelCase operator name to snake_case, e.g. `TensorAdd` -> `tensor_add`."""
        name_new = re.sub(r'([A-Z])', r'_\1', name).lower()
        # A leading capital produces one leading underscore; drop it (safe on an empty name too).
        if name_new.startswith('_'):
            name_new = name_new[1:]
        return name_new

    def _new_act_fake_quant(self, wrapped):
        """Wrap `wrapped` so its output goes through a FakeQuant OP using the activation settings."""
        return _AddFakeQuantAfterSubCell(wrapped,
                                         quant_dtype=self.act_dtype,
                                         quant_delay=self.act_qdelay,
                                         per_channel=self.act_channel,
                                         symmetric=self.act_symmetric,
                                         narrow_range=self.act_range)

    def quantize(self, network):
        """
        Convert `network` to a quantization aware network.

        Args:
            network (Cell): The network to convert.

        Returns:
            Cell, the converted network when `OptimizeOption.QAT` is among the optimize options,
            otherwise the network unchanged.

        Raises:
            KeyError: If the `device_target` context is neither "Ascend" nor "GPU".
        """
        support_device = ["Ascend", "GPU"]
        device_target = context.get_context('device_target')
        if device_target not in support_device:
            raise KeyError("Unsupported {} device target.".format(device_target))

        if OptimizeOption.QAT in self.optimize_option:
            network.update_cell_prefix()
            network = self._convert_subcells2quant(network)
            network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
        """
        changed = False
        for name, subcell in network.name_cells().items():
            if subcell == network:
                continue
            if isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                changed = True
            else:
                # Any other cell is searched recursively for convertible children.
                self._convert_subcells2quant(subcell)
        if changed and isinstance(network, nn.SequentialCell):
            # Refresh the cached list so it picks up the replaced children.
            network.cell_list = list(network.cells())

        # add FakeQuant OP after OP in white list
        # Collect first: the loop below deletes entries from `network.__dict__`.
        add_list = [(name, attr) for name, attr in network.__dict__.items()
                    if not name.startswith('_')
                    and isinstance(attr, ops.Primitive)
                    and attr.name in self.__quant_op_name__]
        for name, prim_op in add_list:
            add_quant = self._new_act_fake_quant(prim_op)
            prefix = self._convert_op_name(prim_op.name)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, prefix])
            add_quant.update_parameters_name(prefix + '.')
            # Replace the plain attribute with a registered child cell of the same name.
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_or_add_activation(self, subcell):
        """Quantize the activation of a fused cell, or append a FakeQuant OP if `after_fake` is set."""
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = self._new_act_fake_quant(F.identity)

    def _convert_conv(self, subcell):
        """
        convert Conv2d cell to quant cell
        """
        conv = subcell.conv
        # Constructor arguments shared by all three quantized convolution variants.
        common = dict(kernel_size=conv.kernel_size,
                      stride=conv.stride,
                      pad_mode=conv.pad_mode,
                      padding=conv.padding,
                      dilation=conv.dilation,
                      group=conv.group,
                      has_bias=conv.has_bias,
                      quant_config=self.quant_config,
                      quant_dtype=self.weight_dtype)
        if subcell.has_bn:
            bn_inner = subcell.batchnorm
            bn_args = dict(eps=bn_inner.eps,
                           momentum=bn_inner.momentum,
                           bias_init=conv.bias_init)
            if self.bn_fold:
                conv_inner = quant.Conv2dBnFoldQuant(conv.in_channels,
                                                     conv.out_channels,
                                                     freeze_bn=self.freeze_bn,
                                                     fake=True,
                                                     **common,
                                                     **bn_args)
                # The folded cell holds the BatchNorm parameters itself.
                bn_target = conv_inner
            else:
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv.in_channels,
                                                            conv.out_channels,
                                                            **common,
                                                            **bn_args)
                # The unfolded cell keeps them on its own `batchnorm` child.
                bn_target = conv_inner.batchnorm
            # change original network BatchNormal OP parameters to quant network
            bn_target.gamma = bn_inner.gamma
            bn_target.beta = bn_inner.beta
            bn_target.moving_mean = bn_inner.moving_mean
            bn_target.moving_variance = bn_inner.moving_variance
            # BatchNorm now lives in the quant conv cell; detach it from the fused cell.
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv.in_channels,
                                           conv.out_channels,
                                           **common)
        # change original network Conv2D OP parameters to quant network
        conv_inner.weight = conv.weight
        if conv.has_bias:
            conv_inner.bias = conv.bias
        subcell.conv = conv_inner
        self._convert_or_add_activation(subcell)
        return subcell

    def _convert_dense(self, subcell):
        """
        convert dense cell to combine dense cell
        """
        dense = subcell.dense
        dense_inner = quant.DenseQuant(dense.in_channels,
                                       dense.out_channels,
                                       has_bias=dense.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # change original network Dense OP parameters to quant network
        dense_inner.weight = dense.weight
        if dense.has_bias:
            dense_inner.bias = dense.bias
        subcell.dense = dense_inner
        self._convert_or_add_activation(subcell)
        return subcell

    def _convert_activation(self, activation):
        """
        Replace a float activation cell with its quantization-aware counterpart.

        Raises:
            ValueError: If the activation's class has no entry in `_ACTIVATION_MAP`.
        """
        act_class = activation.__class__
        if act_class not in _ACTIVATION_MAP:
            # Format the class into the message; passing it as a second argument
            # made the exception render as a tuple.
            raise ValueError("Unsupported activation in auto quant: {}".format(act_class))
        return _ACTIVATION_MAP[act_class](activation=activation,
                                          quant_config=self.quant_config,
                                          quant_dtype=self.act_dtype)
| @@ -17,6 +17,9 @@ | |||||
| import numpy as np | import numpy as np | ||||
| __all__ = ["load_nonquant_param_into_quant_net"] | |||||
| def cal_quantization_params(input_min, | def cal_quantization_params(input_min, | ||||
| input_max, | input_max, | ||||
| data_type, | data_type, | ||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Base Class of Quantizer.""" | |||||
| from abc import ABC, abstractmethod | |||||
| from enum import Enum | |||||
| __all__ = ["OptimizeOption", "Quantizer"] | |||||
class OptimizeOption(Enum):
    """
    Enumeration of the optimize options available for model quantization.
    """
    # Quantization aware training.
    QAT = "QAT"

    def __str__(self):
        """Return the member's raw string value (e.g. ``"QAT"``)."""
        return self.value
class Quantizer(ABC):
    """
    Abstract base class of all quantizers. Subclass it and implement `quantize` to obtain a
    particular quantization result.

    Notes:
        This class is an abstract class and cannot be instantiated directly.

    Args:
        optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options.
            A list or tuple is stored as given; any other value is wrapped in a one-element list.
            Default: None.
    """

    def __init__(self, optimize_option=None):
        already_sequence = isinstance(optimize_option, (list, tuple))
        self.optimize_option = optimize_option if already_sequence else [optimize_option]

    @abstractmethod
    def quantize(self, network):
        """Quantize `network`; every concrete quantizer must override this method."""
| @@ -30,7 +30,7 @@ from mindspore.common.parameter import Parameter | |||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore._checkparam import check_input_data, Validator | from mindspore._checkparam import check_input_data, Validator | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.export import quant_export | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print", | __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print", | ||||
| @@ -596,14 +596,14 @@ def _quant_export(network, *inputs, file_format, **kwargs): | |||||
| network.set_train(False) | network.set_train(False) | ||||
| if file_format == "MINDIR": | if file_format == "MINDIR": | ||||
| if quant_mode == 'MANUAL': | if quant_mode == 'MANUAL': | ||||
| exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True) | |||||
| exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True) | |||||
| else: | else: | ||||
| exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True) | |||||
| exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True) | |||||
| else: | else: | ||||
| if quant_mode == 'MANUAL': | if quant_mode == 'MANUAL': | ||||
| exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs) | |||||
| exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs) | |||||
| else: | else: | ||||
| exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs) | |||||
| exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs) | |||||
| deploy_net = exporter.run() | deploy_net = exporter.run() | ||||
| return deploy_net | return deploy_net | ||||
| @@ -25,7 +25,7 @@ from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.nn.metrics import Accuracy | from mindspore.nn.metrics import Accuracy | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.config import mnist_cfg as cfg | from src.config import mnist_cfg as cfg | ||||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | from src.lenet_fusion import LeNet5 as LeNet5Fusion | ||||
| @@ -47,8 +47,12 @@ if __name__ == "__main__": | |||||
| # define fusion network | # define fusion network | ||||
| network = LeNet5Fusion(cfg.num_classes) | network = LeNet5Fusion(cfg.num_classes) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, | |||||
| per_channel=[True, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(quant_delay=0, | |||||
| bn_fold=False, | |||||
| freeze_bn=10000, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # define loss | # define loss | ||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||
| @@ -22,7 +22,7 @@ import numpy as np | |||||
| import mindspore | import mindspore | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | ||||
| from src.config import mnist_cfg as cfg | from src.config import mnist_cfg as cfg | ||||
| @@ -44,8 +44,12 @@ if __name__ == "__main__": | |||||
| # define fusion network | # define fusion network | ||||
| network = LeNet5Fusion(cfg.num_classes) | network = LeNet5Fusion(cfg.num_classes) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, | |||||
| per_channel=[True, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(quant_delay=0, | |||||
| bn_fold=False, | |||||
| freeze_bn=10000, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # load quantization aware network checkpoint | # load quantization aware network checkpoint | ||||
| param_dict = load_checkpoint(args.ckpt_path) | param_dict = load_checkpoint(args.ckpt_path) | ||||
| load_param_into_net(network, param_dict) | load_param_into_net(network, param_dict) | ||||
| @@ -26,8 +26,8 @@ from mindspore.train.serialization import load_checkpoint | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | ||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.nn.metrics import Accuracy | from mindspore.nn.metrics import Accuracy | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net | |||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.config import mnist_cfg as cfg | from src.config import mnist_cfg as cfg | ||||
| @@ -59,8 +59,11 @@ if __name__ == "__main__": | |||||
| load_nonquant_param_into_quant_net(network, param_dict) | load_nonquant_param_into_quant_net(network, param_dict) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False], | |||||
| quantizer = QuantizationAwareTraining(quant_delay=900, | |||||
| bn_fold=False, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | symmetric=[True, False]) | ||||
| network = quantizer.quantize(network) | |||||
| # define network loss | # define network loss | ||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||
| @@ -21,7 +21,7 @@ from mindspore import context | |||||
| from mindspore import nn | from mindspore import nn | ||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from src.mobilenetV2 import mobilenetV2 | from src.mobilenetV2 import mobilenetV2 | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| @@ -51,7 +51,10 @@ if __name__ == '__main__': | |||||
| # define fusion network | # define fusion network | ||||
| network = mobilenetV2(num_classes=config_device_target.num_classes) | network = mobilenetV2(num_classes=config_device_target.num_classes) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # define network loss | # define network loss | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | ||||
| @@ -21,7 +21,7 @@ import mindspore | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from src.mobilenetV2 import mobilenetV2 | from src.mobilenetV2 import mobilenetV2 | ||||
| from src.config import config_ascend_quant | from src.config import config_ascend_quant | ||||
| @@ -42,7 +42,10 @@ if __name__ == '__main__': | |||||
| # define fusion network | # define fusion network | ||||
| network = mobilenetV2(num_classes=cfg.num_classes) | network = mobilenetV2(num_classes=cfg.num_classes) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # load checkpoint | # load checkpoint | ||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | param_dict = load_checkpoint(args_opt.checkpoint_path) | ||||
| load_param_into_net(network, param_dict) | load_param_into_net(network, param_dict) | ||||
| @@ -26,8 +26,8 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | ||||
| from mindspore.train.serialization import load_checkpoint | from mindspore.train.serialization import load_checkpoint | ||||
| from mindspore.communication.management import init, get_group_size, get_rank | from mindspore.communication.management import init, get_group_size, get_rank | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net | |||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| @@ -99,10 +99,10 @@ def train_on_ascend(): | |||||
| param_dict = load_checkpoint(args_opt.pre_trained) | param_dict = load_checkpoint(args_opt.pre_trained) | ||||
| load_nonquant_param_into_quant_net(network, param_dict) | load_nonquant_param_into_quant_net(network, param_dict) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, | |||||
| bn_fold=True, | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | per_channel=[True, False], | ||||
| symmetric=[True, False]) | symmetric=[True, False]) | ||||
| network = quantizer.quantize(network) | |||||
| # get learning rate | # get learning rate | ||||
| lr = Tensor(get_lr(global_step=config.start_epoch * step_size, | lr = Tensor(get_lr(global_step=config.start_epoch * step_size, | ||||
| @@ -162,12 +162,12 @@ def train_on_gpu(): | |||||
| load_nonquant_param_into_quant_net(network, param_dict) | load_nonquant_param_into_quant_net(network, param_dict) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, | |||||
| bn_fold=True, | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | per_channel=[True, False], | ||||
| symmetric=[True, False], | symmetric=[True, False], | ||||
| freeze_bn=1000000, | freeze_bn=1000000, | ||||
| quant_delay=step_size * 2) | quant_delay=step_size * 2) | ||||
| network = quantizer.quantize(network) | |||||
| # get learning rate | # get learning rate | ||||
| loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | ||||
| @@ -26,7 +26,7 @@ from models.resnet_quant_manual import resnet50_quant #manually construct quanta | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| parser = argparse.ArgumentParser(description='Image classification') | parser = argparse.ArgumentParser(description='Image classification') | ||||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | ||||
| @@ -43,12 +43,13 @@ if args_opt.device_target == "Ascend": | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| # define fusion network | # define fusion network | ||||
| net = resnet50_quant(class_num=config.class_num) | |||||
| network = resnet50_quant(class_num=config.class_num) | |||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| net = quant.convert_quant_network(net, | |||||
| bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # define network loss | # define network loss | ||||
| if not config.use_label_smooth: | if not config.use_label_smooth: | ||||
| config.label_smooth_factor = 0.0 | config.label_smooth_factor = 0.0 | ||||
| @@ -65,13 +66,13 @@ if __name__ == '__main__': | |||||
| # load checkpoint | # load checkpoint | ||||
| if args_opt.checkpoint_path: | if args_opt.checkpoint_path: | ||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | param_dict = load_checkpoint(args_opt.checkpoint_path) | ||||
| not_load_param = load_param_into_net(net, param_dict) | |||||
| not_load_param = load_param_into_net(network, param_dict) | |||||
| if not_load_param: | if not_load_param: | ||||
| raise ValueError("Load param into net fail!") | |||||
| net.set_train(False) | |||||
| raise ValueError("Load param into network fail!") | |||||
| network.set_train(False) | |||||
| # define model | # define model | ||||
| model = Model(net, loss_fn=loss, metrics={'acc'}) | |||||
| model = Model(network, loss_fn=loss, metrics={'acc'}) | |||||
| print("============== Starting Validation ==============") | print("============== Starting Validation ==============") | ||||
| res = model.eval(dataset) | res = model.eval(dataset) | ||||
| @@ -17,14 +17,14 @@ import numpy as np | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant | |||||
| from mindspore.train.quant import quant | |||||
| from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant | |||||
| from mindspore.compression.quant import qat | |||||
| _ema_decay = 0.999 | _ema_decay = 0.999 | ||||
| _symmetric = True | _symmetric = True | ||||
| _fake = True | _fake = True | ||||
| _per_channel = True | _per_channel = True | ||||
| _quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) | |||||
| _quant_config = qat.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) | |||||
| def _weight_variable(shape, factor=0.01): | def _weight_variable(shape, factor=0.01): | ||||
| @@ -90,8 +90,8 @@ class ConvBNReLU(nn.Cell): | |||||
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | ||||
| super(ConvBNReLU, self).__init__() | super(ConvBNReLU, self).__init__() | ||||
| padding = (kernel_size - 1) // 2 | padding = (kernel_size - 1) // 2 | ||||
| conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, | |||||
| group=groups, fake=_fake, quant_config=_quant_config) | |||||
| conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, | |||||
| group=groups, fake=_fake, quant_config=_quant_config) | |||||
| layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] | layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] | ||||
| self.features = nn.SequentialCell(layers) | self.features = nn.SequentialCell(layers) | ||||
| @@ -126,14 +126,14 @@ class ResidualBlock(nn.Cell): | |||||
| channel = out_channel // self.expansion | channel = out_channel // self.expansion | ||||
| self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) | self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) | ||||
| self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) | self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) | ||||
| self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=1, pad_mode='same', padding=0), | |||||
| self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=1, pad_mode='same', padding=0), | |||||
| FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) | FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) | ||||
| ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=1, | |||||
| pad_mode='same', padding=0) | |||||
| ]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=1, | |||||
| pad_mode='same', padding=0) | |||||
| self.down_sample = False | self.down_sample = False | ||||
| @@ -142,20 +142,19 @@ class ResidualBlock(nn.Cell): | |||||
| self.down_sample_layer = None | self.down_sample_layer = None | ||||
| if self.down_sample: | if self.down_sample: | ||||
| self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=stride, | |||||
| pad_mode='same', padding=0), | |||||
| self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=stride, | |||||
| pad_mode='same', padding=0), | |||||
| FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, | FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, | ||||
| symmetric=False) | symmetric=False) | ||||
| ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel, | |||||
| fake=_fake, | |||||
| quant_config=\ | |||||
| _quant_config, | |||||
| kernel_size=1, | |||||
| stride=stride, | |||||
| pad_mode='same', | |||||
| padding=0) | |||||
| ]) if _fake else Conv2dBnFoldQuant(in_channel, out_channel, | |||||
| fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, | |||||
| stride=stride, | |||||
| pad_mode='same', | |||||
| padding=0) | |||||
| self.add = nn.TensorAddQuant() | self.add = nn.TensorAddQuant() | ||||
| self.relu = P.ReLU() | self.relu = P.ReLU() | ||||
| @@ -25,8 +25,8 @@ from mindspore.context import ParallelMode | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | ||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | from mindspore.train.loss_scale_manager import FixedLossScaleManager | ||||
| from mindspore.train.serialization import load_checkpoint | from mindspore.train.serialization import load_checkpoint | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net | |||||
| from mindspore.communication.management import init | from mindspore.communication.management import init | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.common.initializer as weight_init | import mindspore.common.initializer as weight_init | ||||
| @@ -113,7 +113,10 @@ if __name__ == '__main__': | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # get learning rate | # get learning rate | ||||
| lr = get_lr(lr_init=config.lr_init, | lr = get_lr(lr_init=config.lr_init, | ||||
| @@ -29,7 +29,7 @@ from mindspore.context import ParallelMode | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from src.yolo import YOLOV3DarkNet53 | from src.yolo import YOLOV3DarkNet53 | ||||
| from src.logger import get_logger | from src.logger import get_logger | ||||
| @@ -265,10 +265,10 @@ def test(): | |||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| if config.quantization_aware: | if config.quantization_aware: | ||||
| network = quant.convert_quant_network(network, | |||||
| bn_fold=True, | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | per_channel=[True, False], | ||||
| symmetric=[True, False]) | symmetric=[True, False]) | ||||
| network = quantizer.quantize(network) | |||||
| args.logger.info(args.pretrained) | args.logger.info(args.pretrained) | ||||
| if os.path.isfile(args.pretrained): | if os.path.isfile(args.pretrained): | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """hub config.""" | """hub config.""" | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from src.yolo import YOLOV3DarkNet53 | from src.yolo import YOLOV3DarkNet53 | ||||
| from src.config import ConfigYOLOV3DarkNet53 | from src.config import ConfigYOLOV3DarkNet53 | ||||
| @@ -24,9 +24,9 @@ def create_network(name, *args, **kwargs): | |||||
| config = ConfigYOLOV3DarkNet53() | config = ConfigYOLOV3DarkNet53() | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| if config.quantization_aware: | if config.quantization_aware: | ||||
| yolov3_darknet53_quant = quant.convert_quant_network(yolov3_darknet53_quant, | |||||
| bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| yolov3_darknet53_quant = quantizer.quantize(yolov3_darknet53_quant) | |||||
| return yolov3_darknet53_quant | return yolov3_darknet53_quant | ||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | raise NotImplementedError(f"{name} is not implemented in the repo") | ||||
| @@ -27,7 +27,7 @@ from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from mindspore.train.callback import ModelCheckpoint, RunContext | from mindspore.train.callback import ModelCheckpoint, RunContext | ||||
| from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig | from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper | from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper | ||||
| @@ -168,10 +168,10 @@ def train(): | |||||
| config = ConfigYOLOV3DarkNet53() | config = ConfigYOLOV3DarkNet53() | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| if config.quantization_aware: | if config.quantization_aware: | ||||
| network = quant.convert_quant_network(network, | |||||
| bn_fold=True, | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | per_channel=[True, False], | ||||
| symmetric=[True, False]) | symmetric=[True, False]) | ||||
| network = quantizer.quantize(network) | |||||
| network = YoloWithLossCell(network) | network = YoloWithLossCell(network) | ||||
| args.logger.info('finish get network') | args.logger.info('finish get network') | ||||
| @@ -26,8 +26,8 @@ from mindspore.nn.metrics import Accuracy | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | ||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net | |||||
| from dataset import create_dataset | from dataset import create_dataset | ||||
| from config import nonquant_cfg, quant_cfg | from config import nonquant_cfg, quant_cfg | ||||
| from lenet import LeNet5 | from lenet import LeNet5 | ||||
| @@ -73,8 +73,11 @@ def train_lenet_quant(): | |||||
| load_nonquant_param_into_quant_net(network, param_dict) | load_nonquant_param_into_quant_net(network, param_dict) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False], | |||||
| symmetric=[False, False]) | |||||
| quantizer = QuantizationAwareTraining(quant_delay=900, | |||||
| bn_fold=False, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # define network loss | # define network loss | ||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||
| @@ -103,8 +106,12 @@ def eval_quant(): | |||||
| # define fusion network | # define fusion network | ||||
| network = LeNet5Fusion(cfg.num_classes) | network = LeNet5Fusion(cfg.num_classes) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, | |||||
| per_channel=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(quant_delay=0, | |||||
| bn_fold=False, | |||||
| freeze_bn=10000, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # define loss | # define loss | ||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||
| @@ -131,8 +138,12 @@ def export_lenet(): | |||||
| # define fusion network | # define fusion network | ||||
| network = LeNet5Fusion(cfg.num_classes) | network = LeNet5Fusion(cfg.num_classes) | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, | |||||
| per_channel=[True, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(quant_delay=0, | |||||
| bn_fold=False, | |||||
| freeze_bn=10000, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # export network | # export network | ||||
| inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) | inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) | ||||
| @@ -23,7 +23,7 @@ from mindspore import context | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import nn | from mindspore import nn | ||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from dataset import create_dataset | from dataset import create_dataset | ||||
| @@ -84,10 +84,10 @@ def test_mobilenetv2_quant(): | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| network = quant.convert_quant_network(network, | |||||
| bn_fold=True, | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | per_channel=[True, False], | ||||
| symmetric=[True, False]) | symmetric=[True, False]) | ||||
| network = quantizer.quantize(network) | |||||
| # get learning rate | # get learning rate | ||||
| lr = Tensor(get_lr(global_step=config.start_epoch * step_size, | lr = Tensor(get_lr(global_step=config.start_epoch * step_size, | ||||
| @@ -18,14 +18,14 @@ import mindspore.nn as nn | |||||
| import mindspore.common.initializer as weight_init | import mindspore.common.initializer as weight_init | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant | |||||
| from mindspore.train.quant import quant | |||||
| from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant | |||||
| from mindspore.compression.quant import qat | |||||
| _ema_decay = 0.999 | _ema_decay = 0.999 | ||||
| _symmetric = True | _symmetric = True | ||||
| _fake = True | _fake = True | ||||
| _per_channel = True | _per_channel = True | ||||
| _quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) | |||||
| _quant_config = qat.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) | |||||
| def _weight_variable(shape, factor=0.01): | def _weight_variable(shape, factor=0.01): | ||||
| @@ -91,8 +91,8 @@ class ConvBNReLU(nn.Cell): | |||||
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | ||||
| super(ConvBNReLU, self).__init__() | super(ConvBNReLU, self).__init__() | ||||
| padding = (kernel_size - 1) // 2 | padding = (kernel_size - 1) // 2 | ||||
| conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, | |||||
| group=groups, fake=_fake, quant_config=_quant_config) | |||||
| conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, | |||||
| group=groups, fake=_fake, quant_config=_quant_config) | |||||
| layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] | layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] | ||||
| self.features = nn.SequentialCell(layers) | self.features = nn.SequentialCell(layers) | ||||
| @@ -127,14 +127,14 @@ class ResidualBlock(nn.Cell): | |||||
| channel = out_channel // self.expansion | channel = out_channel // self.expansion | ||||
| self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) | self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) | ||||
| self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) | self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) | ||||
| self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=1, pad_mode='same', padding=0), | |||||
| self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=1, pad_mode='same', padding=0), | |||||
| FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) | FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) | ||||
| ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=1, | |||||
| pad_mode='same', padding=0) | |||||
| ]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=1, | |||||
| pad_mode='same', padding=0) | |||||
| self.down_sample = False | self.down_sample = False | ||||
| @@ -143,20 +143,19 @@ class ResidualBlock(nn.Cell): | |||||
| self.down_sample_layer = None | self.down_sample_layer = None | ||||
| if self.down_sample: | if self.down_sample: | ||||
| self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=stride, | |||||
| pad_mode='same', padding=0), | |||||
| self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, stride=stride, | |||||
| pad_mode='same', padding=0), | |||||
| FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, | FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, | ||||
| symmetric=False) | symmetric=False) | ||||
| ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel, | |||||
| fake=_fake, | |||||
| quant_config=\ | |||||
| _quant_config, | |||||
| kernel_size=1, | |||||
| stride=stride, | |||||
| pad_mode='same', | |||||
| padding=0) | |||||
| ]) if _fake else Conv2dBnFoldQuant(in_channel, out_channel, | |||||
| fake=_fake, | |||||
| quant_config=_quant_config, | |||||
| kernel_size=1, | |||||
| stride=stride, | |||||
| pad_mode='same', | |||||
| padding=0) | |||||
| self.add = nn.TensorAddQuant() | self.add = nn.TensorAddQuant() | ||||
| self.relu = P.ReLU() | self.relu = P.ReLU() | ||||
| @@ -22,7 +22,7 @@ from mindspore import context | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.nn.optim.momentum import Momentum | from mindspore.nn.optim.momentum import Momentum | ||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.train.quant import quant | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore import set_seed | from mindspore import set_seed | ||||
| from resnet_quant_manual import resnet50_quant | from resnet_quant_manual import resnet50_quant | ||||
| @@ -89,10 +89,10 @@ def test_resnet50_quant(): | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| # convert fusion network to quantization aware network | # convert fusion network to quantization aware network | ||||
| net = quant.convert_quant_network(net, | |||||
| bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| net = quantizer.quantize(net) | |||||
| # get learning rate | # get learning rate | ||||
| lr = Tensor(get_lr(lr_init=config.lr_init, | lr = Tensor(get_lr(lr_init=config.lr_init, | ||||
| @@ -19,7 +19,8 @@ import pytest | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import nn | from mindspore import nn | ||||
| from mindspore.train.quant import quant as qat | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | |||||
| from mindspore.compression.export import quant_export | |||||
| from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2 | from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2 | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| @@ -66,27 +67,35 @@ class LeNet5(nn.Cell): | |||||
| def test_qat_lenet(): | def test_qat_lenet(): | ||||
| img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) | img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) | ||||
| net = LeNet5() | net = LeNet5() | ||||
| net = qat.convert_quant_network( | |||||
| net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| net = quantizer.quantize(net) | |||||
| # should load the checkpoint. mock here | # should load the checkpoint. mock here | ||||
| net.init_parameters_data() | net.init_parameters_data() | ||||
| qat.export(net, img, file_name="quant.pb") | |||||
| quant_export.export(net, img, file_name="quant.pb") | |||||
| @pytest.mark.skip(reason="no `te.lang.cce` in ut env") | @pytest.mark.skip(reason="no `te.lang.cce` in ut env") | ||||
| def test_qat_mobile_per_channel_tf(): | def test_qat_mobile_per_channel_tf(): | ||||
| network = mobilenetV2(num_classes=1000) | network = mobilenetV2(num_classes=1000) | ||||
| img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) | img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) | ||||
| network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[True, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # should load the checkpoint. mock here | # should load the checkpoint. mock here | ||||
| network.init_parameters_data() | network.init_parameters_data() | ||||
| qat.export(network, img, file_name="quant.pb") | |||||
| quant_export.export(network, img, file_name="quant.pb") | |||||
| @pytest.mark.skip(reason="no `te.lang.cce` in ut env") | @pytest.mark.skip(reason="no `te.lang.cce` in ut env") | ||||
| def test_qat_mobile_per_channel_ff(): | def test_qat_mobile_per_channel_ff(): | ||||
| network = mobilenetV2(num_classes=1000) | network = mobilenetV2(num_classes=1000) | ||||
| img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) | img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) | ||||
| network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, False], symmetric=[True, False]) | |||||
| quantizer = QuantizationAwareTraining(bn_fold=True, | |||||
| per_channel=[False, False], | |||||
| symmetric=[True, False]) | |||||
| network = quantizer.quantize(network) | |||||
| # should load the checkpoint. mock here | # should load the checkpoint. mock here | ||||
| network.init_parameters_data() | network.init_parameters_data() | ||||
| qat.export(network, img, file_name="quant.pb") | |||||
| quant_export.export(network, img, file_name="quant.pb") | |||||