|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- Quantization aware training
-
- Users can train a model with quantization aware training. MindSpore supports quantization aware training,
- which models quantization errors in both the forward and backward passes using fake-quantization
- operations. Note that the entire computation is carried out in floating point. At the end of quantization
- aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
- """
-
- import re
- import mindspore.context as context
- import numpy as np
- from ... import nn, ops
- from ..._checkparam import Validator, Rel
- from ...nn.layer import quant
- from ...ops import functional as F
- from ..common import QuantDtype
- from .quantizer import Quantizer, OptimizeOption
- from .quant_utils import compute_KL_threshold
-
-
# Names exported by `from <this module> import *`.
__all__ = [
    "QuantizationAwareTraining",
    "create_quant_config",
]
-
-
def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
                        quant_delay=(0, 0),
                        quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                        per_channel=(False, False),
                        symmetric=(False, False),
                        narrow_range=(False, False),
                        mode="DEFAULT"):
    r"""
    Config the observer type of weights and data flow with quant params.

    Every argument except `mode` may be given either as a single value, which is then used for both
    weights and data flow, or as a list/tuple whose first element applies to weights and whose last
    (second) element applies to data flow. A one-element list/tuple is shared by both.

    Args:
        quant_observer (Union[Observer, list, tuple]): The observer type to do quantization. The first element
            represents weights and second element represents data flow.
            Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver)
        quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized during
            eval. The first element represents weights and second element represents data flow. Default: (0, 0)
        quant_dtype (Union[QuantDtype, list, tuple]): Datatype to use for quantize weights and activations. The first
            element represents weights and second element represents data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represents weights
            and second element represents data flow, and second element must be `False` now. Default: (False, False)
        symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then
            base on symmetric otherwise base on asymmetric. The first element represents weights and second
            element represents data flow. Default: (False, False)
        narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the second element represents data flow. Default: (False, False)
        mode (str): Optional quantization mode, currently only `DEFAULT`(QAT) and `LEARNED_SCALE` are supported.
            Default: "DEFAULT"

    Returns:
        QuantConfig, Contains the observer type of weight and activation.

    Raises:
        ValueError: If a list/tuple argument is empty, or if the data-flow element of `per_channel` is True.
    """
    def _as_sequence(name, value):
        # Wrap scalars so the [0] (weights) / [-1] (data flow) indexing below also works for them.
        if not isinstance(value, (list, tuple)):
            return (value,)
        if not value:
            raise ValueError("Arg '{}' should not be empty.".format(name))
        return value

    quant_observer = _as_sequence("quant_observer", quant_observer)
    quant_delay = _as_sequence("quant_delay", quant_delay)
    quant_dtype = _as_sequence("quant_dtype", quant_dtype)
    per_channel = _as_sequence("per_channel", per_channel)
    symmetric = _as_sequence("symmetric", symmetric)
    narrow_range = _as_sequence("narrow_range", narrow_range)

    if per_channel[-1]:
        raise ValueError("Arg 'per_channel' second element must be 'False'.")
    weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
                                                     per_channel=per_channel[0], symmetric=symmetric[0],
                                                     narrow_range=narrow_range[0], mode=mode)
    act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
                                                   per_channel=per_channel[-1], symmetric=symmetric[-1],
                                                   narrow_range=narrow_range[-1], mode=mode)
    return quant.QuantConfig(weight=weight_observer, activation=act_observer)
-
-
class _AddFakeQuantInput(nn.Cell):
    """
    Wrap a network so that its single input is passed through a FakeQuant observer first.

    Args:
        network (Cell): The wrapped network; it is called with the fake-quantized input.
        quant_delay (int): Forwarded to the input `FakeQuantWithMinMaxObserver`. Default: 0.
    """

    def __init__(self, network, quant_delay=0):
        super().__init__(auto_prefix=False)
        # Input observer with an initial range of [-6, 6] and `ema=True`.
        self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6,
                                                                  max_init=6,
                                                                  quant_delay=quant_delay,
                                                                  ema=True)
        self.fake_quant_input.update_parameters_name('fake_quant_input.')
        self.network = network

    def construct(self, data):
        quantized_data = self.fake_quant_input(data)
        return self.network(quantized_data)
-
-
class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Wrap a sub cell (or primitive / function) and fake-quantize its output.

    Args:
        subcell: The callable whose output is fake-quantized.
        **kwargs: Must contain `quant_dtype`, `quant_delay`, `per_channel`, `symmetric` and
            `narrow_range`, which are forwarded to the output `FakeQuantWithMinMaxObserver`, plus
            `optimize_option`. If `optimize_option` contains `OptimizeOption.LEARNED_SCALE`, the observer
            uses mode "LEARNED_SCALE" with initial range [-16, 16]. Otherwise it uses mode "DEFAULT"
            with initial range [-6, 6].
    """

    def __init__(self, subcell, **kwargs):
        super().__init__(auto_prefix=False)
        self.subcell = subcell
        if OptimizeOption.LEARNED_SCALE not in kwargs["optimize_option"]:
            self.mode, self.max_init, self.min_init = "DEFAULT", 6, -6
        else:
            self.mode, self.max_init, self.min_init = "LEARNED_SCALE", 16, -16

        self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=self.min_init,
                                                                max_init=self.max_init,
                                                                ema=True,
                                                                quant_dtype=kwargs["quant_dtype"],
                                                                quant_delay=kwargs["quant_delay"],
                                                                per_channel=kwargs["per_channel"],
                                                                symmetric=kwargs["symmetric"],
                                                                narrow_range=kwargs["narrow_range"],
                                                                mode=self.mode)

    def construct(self, *data):
        return self.fake_quant_act(self.subcell(*data))
-
-
class QuantizationAwareTraining(Quantizer):
    r"""
    Quantizer for quantization aware training.

    Args:
        bn_fold (bool): Whether to use BatchNorm-folding ops to simulate the inference operation. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters use the total mean and variance.
            Default: 10000000.
        quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized during
            eval. The first element represents weights and second element represents data flow. Default: (0, 0)
        quant_dtype (Union[QuantDtype, list, tuple]): Datatype to use for quantize weights and activations. The first
            element represents weights and second element represents data flow. It is necessary to consider the
            precision support of hardware devices in the practical quantization infer scenario.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represents weights
            and second element represents data flow, and second element must be `False` now. Default: (False, False)
        symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then
            base on symmetric otherwise base on asymmetric. The first element represents weights and second
            element represents data flow. Default: (False, False)
        narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the second element represents data flow. Default: (False, False)
        optimize_option (Union[OptimizeOption, list, tuple]): Specifies the quant algorithm and options. Currently only
            QAT and LEARNED_SCALE are supported. Note that if both QAT and LEARNED_SCALE are configured,
            LEARNED_SCALE has a higher priority. LEARNED_SCALE currently only works under some constraints, which
            include: freeze_bn=0, quant_delay=0, symmetric=True, narrow_range=True. More specifically, for operators
            such as ReLU and ReLU6, which only have positive values, we add a negative truncation to optimize this
            scenario, and narrow_range will automatically match to False. Default: OptimizeOption.QAT
        one_conv_fold (bool): When `bn_fold` is True, whether to use `Conv2dBnFoldQuantOneConv` (True) or
            `Conv2dBnFoldQuant` (False) for convolutions followed by BatchNorm. Default: True.

    Raises:
        ValueError: If a list/tuple argument is empty or has more than 2 elements, or if LEARNED_SCALE is
            requested with unsupported `symmetric`, `narrow_range`, `freeze_bn` or `quant_delay` values.

    Examples:
        >>> class LeNet5(nn.Cell):
        ...     def __init__(self, num_class=10, channel=1):
        ...         super(LeNet5, self).__init__()
        ...         self.type = "fusion"
        ...         self.num_class = num_class
        ...
        ...         # change `nn.Conv2d` to `nn.Conv2dBnAct`
        ...         self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        ...         self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        ...         # change `nn.Dense` to `nn.DenseBnAct`
        ...         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        ...         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        ...         self.fc3 = nn.DenseBnAct(84, self.num_class)
        ...
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...
        ...     def construct(self, x):
        ...         x = self.conv1(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.conv2(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.flatten(x)
        ...         x = self.fc1(x)
        ...         x = self.fc2(x)
        ...         x = self.fc3(x)
        ...         return x
        ...
        >>> net = LeNet5()
        >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
        >>> net_qat = quantizer.quantize(net)
    """
    __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]

    def __init__(self,
                 bn_fold=True,
                 freeze_bn=10000000,
                 quant_delay=(0, 0),
                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                 per_channel=(False, False),
                 symmetric=(False, False),
                 narrow_range=(False, False),
                 optimize_option=OptimizeOption.QAT,
                 one_conv_fold=True):
        """Init for QuantizationAwareTraining quantizer"""
        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)

        def convert2list(name, value):
            # A scalar is shared by weights and activations; a list/tuple is indexed with [0] (weights)
            # and [-1] (activations) below, so it must hold one or two elements.
            if not isinstance(value, (list, tuple)):
                return [value]
            if not value:
                raise ValueError("input `{}` should not be empty".format(name))
            if len(value) > 2:
                raise ValueError("input `{}` len should be less than or equal to 2".format(name))
            return value

        quant_delay = convert2list("quant delay", quant_delay)
        quant_dtype = convert2list("quant dtype", quant_dtype)
        per_channel = convert2list("per channel", per_channel)
        symmetric = convert2list("symmetric", symmetric)
        narrow_range = convert2list("narrow range", narrow_range)

        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
        self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold")
        self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
                                    nn.DenseBnAct: self._convert_dense}
        self.mode = "DEFAULT"
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            self.mode = "LEARNED_SCALE"
            if not self.weight_symmetric or not self.act_symmetric:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only support "
                                 "symmetric=(True, True) for quant")
            if not self.weight_range or not self.act_range:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only support narrow_range=(True, True) "
                                 "for quant")
            if self.freeze_bn != 0:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only support freeze_bn equal to 0, "
                                 "but get freeze_bn={}".format(self.freeze_bn))
            if self.weight_qdelay != 0 or self.act_qdelay != 0:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only support quant_delay=(0, 0)")
        self.quant_config = create_quant_config(quant_delay=quant_delay,
                                                quant_dtype=quant_dtype,
                                                per_channel=per_channel,
                                                symmetric=symmetric,
                                                narrow_range=narrow_range,
                                                mode=self.mode)
        # Epsilon added to the BatchNorm moving variance when folding its scale into weights.
        self.eps = 1e-5

    def _convert_op_name(self, name):
        """Convert a CamelCase primitive name to snake_case, e.g. "RealDiv" -> "real_div"."""
        name_new = re.sub(r'([A-Z])', r'_\1', name).lower()
        # Only the underscore inserted before the leading capital letter is removed.
        if name_new.startswith('_'):
            name_new = name_new[1:]
        return name_new

    def quantize(self, network):
        """
        Quant API to convert input network to a quantization aware training network

        Args:
            network (Cell): network to be quantized.

        Returns:
            Cell, the converted network.

        Raises:
            KeyError: If the current `device_target` is neither "Ascend" nor "GPU".

        Examples:
            >>> net = Net()
            >>> quantizer = QuantizationAwareTraining()
            >>> net_qat = quantizer.quantize(net)
        """
        support_device = ["Ascend", "GPU"]
        if context.get_context('device_target') not in support_device:
            raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

        if OptimizeOption.QAT in self.optimize_option or OptimizeOption.LEARNED_SCALE in self.optimize_option:
            network.update_cell_prefix()
            network = self._convert_subcells2quant(network)
            network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        Recursively convert sub cells like `Conv2dBnAct` and `DenseBnAct` to quant cells, and wrap
        white-listed primitives (see `__quant_op_name__`) with an output FakeQuant cell.

        Args:
            network (Cell): the cell to convert; it is modified in place.

        Returns:
            Cell, the same `network` object.
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            if isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                self._convert_subcells2quant(subcell)
        # Keep SequentialCell's cached cell list in sync with the replaced children.
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # add FakeQuant OP after OP in white list, but not including those wrapped in the below quantization cell.
        if isinstance(network, (nn.FakeQuantWithMinMaxObserver,
                                nn.Conv2dBnFoldQuantOneConv,
                                nn.Conv2dBnFoldQuant,
                                nn.Conv2dBnWithoutFoldQuant,
                                nn.Conv2dQuant,
                                nn.DenseQuant,
                                nn.ActQuant,
                                nn.TensorAddQuant,
                                nn.MulQuant)):
            return network

        # Collect first, then mutate, so `network.__dict__` is not changed while being iterated.
        add_list = []
        for name in network.__dict__:
            if name.startswith('_'):
                continue
            attr = network.__dict__[name]
            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                  quant_dtype=self.act_dtype,
                                                  quant_delay=self.act_qdelay,
                                                  per_channel=self.act_channel,
                                                  symmetric=self.act_symmetric,
                                                  narrow_range=self.act_range,
                                                  optimize_option=self.optimize_option)
            prefix = self._convert_op_name(prim_op.name)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, prefix])
            add_quant.update_parameters_name(prefix + '.')
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _scale_weight_by_bn(self, weight_para, gamma, moving_variance):
        """
        Multiply each slice of `weight_para` along axis 0 by gamma / sqrt(moving_variance + eps).

        The scale is reshaped to (-1, 1, ..., 1) with as many trailing axes as the weight has, so it
        broadcasts over axis 0 for 4-D convolution weights as well as 2-D dense weights.
        """
        scale_factor = gamma / np.sqrt(moving_variance + self.eps)
        return weight_para * scale_factor.reshape((-1,) + (1,) * (weight_para.ndim - 1))

    def _update_weight_init_by_kl(self, weight_para):
        """Set the weight observer's min_init/max_init in `self.quant_config` from KL thresholds."""
        min_init, max_init = self._KL_init(weight_para, self.weight_dtype)
        self.quant_config = self.quant_config._replace(
            weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))

    def _quantize_subcell_activation(self, subcell):
        """Convert the activation of a converted `Conv2dBnAct`/`DenseBnAct` cell, in place."""
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                           quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay,
                                                           per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric,
                                                           narrow_range=self.act_range,
                                                           optimize_option=self.optimize_option)

    def _convert_conv(self, subcell):
        """
        Convert a `Conv2dBnAct` cell to its quantization aware counterpart, in place.

        The inner convolution is replaced by a fake-quant convolution chosen from `has_bn`, `bn_fold`
        and `one_conv_fold`. The original weight, bias and BatchNorm parameters are transferred to it,
        and the activation is converted as well.

        Args:
            subcell (Cell): the `Conv2dBnAct` cell to convert.

        Returns:
            Cell, the same `subcell` object.
        """
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            subcell_weight_para = subcell.conv.weight.data.asnumpy()
            if subcell.has_bn:
                subcell_weight_para = self._scale_weight_by_bn(subcell_weight_para,
                                                               subcell.batchnorm.gamma.data.asnumpy(),
                                                               subcell.batchnorm.moving_variance.data.asnumpy())
            self._update_weight_init_by_kl(subcell_weight_para)

        conv_inner = subcell.conv
        if subcell.has_bn:
            bn_inner = subcell.batchnorm
            if self.bn_fold:
                if self.one_conv_fold:
                    conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels,
                                                                conv_inner.out_channels,
                                                                kernel_size=conv_inner.kernel_size,
                                                                stride=conv_inner.stride,
                                                                pad_mode=conv_inner.pad_mode,
                                                                padding=conv_inner.padding,
                                                                dilation=conv_inner.dilation,
                                                                group=conv_inner.group,
                                                                eps=bn_inner.eps,
                                                                momentum=1 - bn_inner.momentum,
                                                                has_bias=conv_inner.has_bias,
                                                                bias_init=conv_inner.bias_init,
                                                                quant_config=self.quant_config,
                                                                quant_dtype=self.weight_dtype,
                                                                fake=True)
                else:
                    conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
                                                         conv_inner.out_channels,
                                                         kernel_size=conv_inner.kernel_size,
                                                         stride=conv_inner.stride,
                                                         pad_mode=conv_inner.pad_mode,
                                                         padding=conv_inner.padding,
                                                         dilation=conv_inner.dilation,
                                                         group=conv_inner.group,
                                                         eps=bn_inner.eps,
                                                         momentum=1 - bn_inner.momentum,
                                                         has_bias=conv_inner.has_bias,
                                                         bias_init=conv_inner.bias_init,
                                                         freeze_bn=self.freeze_bn,
                                                         quant_config=self.quant_config,
                                                         quant_dtype=self.weight_dtype,
                                                         fake=True)
                # change original network Batch Normalization OP parameters to quant network
                conv_inner.gamma = subcell.batchnorm.gamma
                conv_inner.beta = subcell.batchnorm.beta
                conv_inner.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.moving_variance = subcell.batchnorm.moving_variance
            else:
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
                                                            conv_inner.out_channels,
                                                            kernel_size=conv_inner.kernel_size,
                                                            stride=conv_inner.stride,
                                                            pad_mode=conv_inner.pad_mode,
                                                            padding=conv_inner.padding,
                                                            dilation=conv_inner.dilation,
                                                            group=conv_inner.group,
                                                            eps=bn_inner.eps,
                                                            momentum=1 - bn_inner.momentum,
                                                            has_bias=conv_inner.has_bias,
                                                            bias_init=conv_inner.bias_init,
                                                            quant_config=self.quant_config,
                                                            quant_dtype=self.weight_dtype)
                # change original network Batch Normalization OP parameters to quant network
                conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
                conv_inner.batchnorm.beta = subcell.batchnorm.beta
                conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
            # The BatchNorm parameters now live in the quant conv cell; remove the standalone BatchNorm.
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size, stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode, padding=conv_inner.padding,
                                           dilation=conv_inner.dilation, group=conv_inner.group,
                                           has_bias=conv_inner.has_bias, quant_config=self.quant_config,
                                           quant_dtype=self.weight_dtype)
        # change original network Conv2D OP parameters to quant network
        conv_inner.weight = subcell.conv.weight
        if subcell.conv.has_bias:
            conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        self._quantize_subcell_activation(subcell)
        return subcell

    def _convert_dense(self, subcell):
        """
        Convert a `DenseBnAct` cell to its quantization aware counterpart, in place.

        The inner dense layer is replaced by a `DenseQuant` that reuses the original weight and bias,
        and the activation is converted as well.

        Args:
            subcell (Cell): the `DenseBnAct` cell to convert.

        Returns:
            Cell, the same `subcell` object.
        """
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            subcell_weight_para = subcell.dense.weight.data.asnumpy()
            if subcell.has_bn:
                # Scale per output row; the previous reshape(-1, 1, 1, 1) broadcast a 2-D dense weight
                # into a 4-D array and produced wrong KL thresholds.
                subcell_weight_para = self._scale_weight_by_bn(subcell_weight_para,
                                                               subcell.batchnorm.gamma.data.asnumpy(),
                                                               subcell.batchnorm.moving_variance.data.asnumpy())
            self._update_weight_init_by_kl(subcell_weight_para)

        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # change original network Dense OP parameters to quant network
        dense_inner.weight = subcell.dense.weight
        if subcell.dense.has_bias:
            dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        self._quantize_subcell_activation(subcell)
        return subcell

    def _convert_activation(self, activation):
        """
        Convert an activation cell to an `ActQuant` cell.

        Raises:
            ValueError: If the activation class is not supported.
        """
        act_class = activation.__class__
        act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
        act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]

        if act_class in act_list:
            return quant.ActQuant(activation=activation,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        if act_class in act_list_with_fake_before:
            return quant.ActQuant(activation=activation,
                                  ema=True,
                                  fake_before=True,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        raise ValueError("Unsupported activation in auto quant: {}".format(act_class))

    def _KL_init(self, subcell_weight_para, weight_dtype):
        """
        Calculate the value of max_init and min_init with compute_KL_threshold.

        With per-channel weight quantization, one threshold is computed for each slice along axis 0;
        otherwise a single threshold is computed for the whole array. min_init is the negated max_init.

        Returns:
            tuple(list, list), (min_init, max_init).
        """
        if self.weight_channel:
            max_init = [compute_KL_threshold(weight_para_each, weight_dtype)
                        for weight_para_each in subcell_weight_para]
        else:
            max_init = [compute_KL_threshold(subcell_weight_para, weight_dtype)]
        min_init = [-x for x in max_init]
        return min_init, max_init

    def set_mixed_bits(self, network, strategy):
        r"""
        Set network's quantization strategy, this function is currently only valid for `LEARNED_SCALE`
        optimize_option.

        Inputs:
            network (Cell): input network
            strategy (List): the quantization strategy for layers that need to be quantified (eg. [[8], [8],
                ..., [6], [4], [8]]), currently only the quant_dtype for weights of the dense layer and the
                convolution layer is supported.

        Outputs:
            network (Cell)

        Raises:
            ValueError: If `LEARNED_SCALE` is not enabled, if the length of `strategy` differs from the number
                of `Conv2dBnAct`/`DenseBnAct` cells in `network`, or if a requested bit width is unsupported.
        """
        if OptimizeOption.LEARNED_SCALE not in self.optimize_option:
            raise ValueError("The `set_mixed_bits` function is currently only valid for `LEARNED_SCALE` "
                             "optimize_option.")

        self.quantizable_idx = [i for i, (_, cell) in enumerate(network.cells_and_names())
                                if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct))]
        if len(self.quantizable_idx) != len(strategy):
            raise ValueError("The length of `strategy` ({}) must be equal to the number of quantizable layers ({})."
                             .format(len(strategy), len(self.quantizable_idx)))
        quantizable_layer_bit_dict = dict(zip(self.quantizable_idx, strategy))
        type_map = {
            QuantDtype.INT2.num_bits: QuantDtype.INT2,
            QuantDtype.INT3.num_bits: QuantDtype.INT3,
            QuantDtype.INT4.num_bits: QuantDtype.INT4,
            QuantDtype.INT5.num_bits: QuantDtype.INT5,
            QuantDtype.INT6.num_bits: QuantDtype.INT6,
            QuantDtype.INT7.num_bits: QuantDtype.INT7,
            QuantDtype.INT8.num_bits: QuantDtype.INT8
        }
        for i, (_, cell) in enumerate(network.cells_and_names()):
            if i not in quantizable_layer_bit_dict:
                continue
            if not isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                continue
            num_bits = quantizable_layer_bit_dict[i][0]
            if num_bits not in type_map:
                raise ValueError("Unsupported quant bits {} for layer index {}, supported bits are {}."
                                 .format(num_bits, i, sorted(type_map)))
            cell.weight_dtype = type_map[num_bits]
            if isinstance(cell, nn.Conv2dBnAct):
                subcell_weight_para = cell.conv.weight.data.asnumpy()
                if hasattr(cell.conv, 'gamma'):
                    subcell_weight_para = self._scale_weight_by_bn(subcell_weight_para,
                                                                   cell.conv.gamma.data.asnumpy(),
                                                                   cell.conv.moving_variance.data.asnumpy())
                min_init, max_init = self._KL_init(subcell_weight_para, cell.weight_dtype)
                cell.conv.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                  min_init=min_init,
                                                  max_init=max_init)
            else:
                subcell_weight_para = cell.dense.weight.data.asnumpy()
                if hasattr(cell.dense, 'gamma'):
                    subcell_weight_para = self._scale_weight_by_bn(subcell_weight_para,
                                                                   cell.dense.gamma.data.asnumpy(),
                                                                   cell.dense.moving_variance.data.asnumpy())
                min_init, max_init = self._KL_init(subcell_weight_para, cell.weight_dtype)
                cell.dense.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                   min_init=min_init,
                                                   max_init=max_init)
        return network
|