
add QuantConfig & modify quant cells

tags/v1.1.0
yuchaojie 5 years ago
commit a84affffd7
4 changed files with 231 additions and 428 deletions
  1. mindspore/nn/layer/quant.py  (+118, -340)
  2. mindspore/train/quant/quant.py  (+81, -48)
  3. model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py  (+15, -16)
  4. tests/st/quantization/resnet50_quant/resnet_quant_manual.py  (+17, -24)

mindspore/nn/layer/quant.py  (+118, -340)

@@ -15,6 +15,7 @@
"""Quantization aware training.""" """Quantization aware training."""


from functools import partial from functools import partial
from collections import namedtuple
import numpy as np import numpy as np
from mindspore import nn from mindspore import nn
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@@ -34,7 +35,7 @@ from ...ops.operations import _quant_ops as Q
__all__ = [
'Conv2dBnAct',
'DenseBnAct',
'FakeQuantWithMinMax',
'FakeQuantWithMinMaxObserver',
'Conv2dBnFoldQuant',
'Conv2dBnWithoutFoldQuant',
'Conv2dQuant',
@@ -422,14 +423,14 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
symmetric=False,
narrow_range=False,
quant_delay=0):
"""Initialize FakeQuantWithMinMax layer"""
"""Initialize FakeQuantWithMinMaxObserver"""
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
symmetric=symmetric, narrow_range=narrow_range,
num_channels=num_channels)
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init
self.max_init = max_init
self.quant_dtype = quant_dtype
@@ -498,119 +499,9 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
return out




class FakeQuantWithMinMax(Cell):
r"""
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.

Args:
min_init (int, float): The initialized min value. Default: -6.
max_init (int, float): The initialized max value. Default: 6.
ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
num_channels (int): declarate the min and max channel size, Default: 1.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.

Outputs:
Tensor, with the same type and shape as the `x`.

Examples:
>>> fake_quant = FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""

def __init__(self,
min_init=-6,
max_init=6,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
num_channels=1,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
"""Initialize FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.num_channels = num_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.is_ascend = context.get_context('device_target') == "Ascend"

# init tensor min and max for fake quant op
if self.per_channel:
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
else:
min_array = np.array([self.min_init]).astype(np.float32)
max_array = np.array([self.max_init]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

# init fake quant relative op
if self.per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else:
quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.MinMaxUpdatePerLayer

self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
if self.is_ascend:
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=self.quant_delay)
self.fake_quant_infer = self.fake_quant_train
else:
quant_fun = partial(quant_fun,
ema=self.ema,
ema_decay=ema_decay,
num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=self.quant_delay)
self.fake_quant_train = quant_fun(training=True)
self.fake_quant_infer = quant_fun(training=False)

def extend_repr(self):
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range,
self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.num_channels, self.quant_delay,
self.min_init, self.max_init)
return s
QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])


def construct(self, x):
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver)
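For orientation (not part of the commit): QuantConfig is a two-field namedtuple that bundles the observer constructors for weights and activations, and quant_config_default points both fields at FakeQuantWithMinMaxObserver. A minimal sketch of building a custom config; functools.partial is used purely for illustration here, while the commit's own helper get_quant_config (in mindspore/train/quant/quant.py below) does the same through the observer's partial_init:

    from functools import partial
    from mindspore.nn.layer.quant import QuantConfig, FakeQuantWithMinMaxObserver

    # per-channel symmetric observer for weights, default observer for activations
    custom_config = QuantConfig(
        weight=partial(FakeQuantWithMinMaxObserver, per_channel=True, symmetric=True),
        activation=FakeQuantWithMinMaxObserver)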




class Conv2dBnFoldQuant(Cell):
@@ -641,12 +532,9 @@ class Conv2dBnFoldQuant(Cell):
mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'ones'.
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMax op. Default: True.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): The Quantization delay parameters according to the global step. Default: 0.
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
freeze_bn (int): The quantization freeze BatchNormal op is according to the global step. Default: 100000.


Inputs:
@@ -680,11 +568,8 @@ class Conv2dBnFoldQuant(Cell):
mean_init='zeros',
var_init='ones',
fake=True,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8,
freeze_bn=100000):
"""Initialize Conv2dBnFoldQuant layer"""
super(Conv2dBnFoldQuant, self).__init__()
@@ -699,13 +584,10 @@ class Conv2dBnFoldQuant(Cell):
self.eps = eps
self.momentum = momentum
self.has_bias = has_bias
self.quant_delay = quant_delay
self.freeze_bn = freeze_bn
self.fake = fake
self.num_bits = num_bits
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range
self.quant_config = quant_config
self.quant_dtype = quant_dtype
self.is_gpu = context.get_context('device_target') == "GPU"


# initialize convolution op and Parameter
@@ -745,16 +627,12 @@ class Conv2dBnFoldQuant(Cell):
requires_grad=False)


# initialize fake ops
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
per_channel=per_channel,
channel_axis=channel_axis,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = Q.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend":
@@ -777,7 +655,7 @@ class Conv2dBnFoldQuant(Cell):
self.pad_mode, self.padding, self.dilation,
self.group,
self.fake, self.freeze_bn, self.momentum,
self.quant_delay)
self.fake_quant_weight.quant_delay)
return s


def construct(self, x):
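For orientation (not part of the commit): a minimal usage sketch of the reworked Conv2dBnFoldQuant signature. The channel sizes are arbitrary, and the import paths for QuantDtype and quant_config_default are inferred from the imports added elsewhere in this commit.

    from mindspore import nn
    from mindspore.compression.common import QuantDtype
    from mindspore.nn.layer.quant import quant_config_default

    # quant_config/quant_dtype replace the old per_channel/num_bits/symmetric/narrow_range/quant_delay args
    conv = nn.Conv2dBnFoldQuant(in_channels=3, out_channels=16, kernel_size=3, stride=1,
                                pad_mode='same', fake=True, freeze_bn=100000,
                                quant_config=quant_config_default,
                                quant_dtype=QuantDtype.INT8)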
@@ -836,11 +714,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -868,11 +743,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
momentum=0.997,
weight_init='normal',
bias_init='zeros',
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(Conv2dBnWithoutFoldQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
@@ -886,7 +758,6 @@ class Conv2dBnWithoutFoldQuant(Cell):
self.pad_mode = pad_mode
self.padding = padding
self.group = group
self.quant_delay = quant_delay


self.bias_add = P.BiasAdd()
if Validator.check_bool(has_bias):
@@ -917,16 +788,12 @@ class Conv2dBnWithoutFoldQuant(Cell):
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
per_channel=per_channel,
channel_axis=channel_axis,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)


def construct(self, x):
@@ -942,7 +809,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.quant_delay)
self.has_bias, self.fake_quant_weight.quant_delay)
return s




@@ -966,11 +833,8 @@ class Conv2dQuant(Cell):
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -996,11 +860,8 @@ class Conv2dQuant(Cell):
has_bias=False,
weight_init='normal',
bias_init='zeros',
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(Conv2dQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
@@ -1014,7 +875,6 @@ class Conv2dQuant(Cell):
self.pad_mode = pad_mode
self.padding = padding
self.group = group
self.quant_delay = quant_delay


weight_shape = [out_channels, in_channels // group, *self.kernel_size]
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
@@ -1033,16 +893,12 @@ class Conv2dQuant(Cell):
stride=self.stride,
dilation=self.dilation,
group=self.group)
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
per_channel=per_channel,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)


def construct(self, x):
weight = self.fake_quant_weight(self.weight)
@@ -1056,7 +912,7 @@ class Conv2dQuant(Cell):
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.quant_delay)
self.has_bias, self.fake_quant_weight.quant_delay)
return s




@@ -1075,11 +931,8 @@ class DenseQuant(Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -1093,19 +946,15 @@ class DenseQuant(Cell):
>>> result = dense_quant(input_x)
"""


def __init__(
self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(DenseQuant, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
@@ -1132,16 +981,12 @@ class DenseQuant(Cell):


self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
per_channel=per_channel,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)


def construct(self, x):
"""Use operators to construct the Dense layer."""
@@ -1179,16 +1024,13 @@ class ActQuant(_QuantActivation):
Quantization aware training activation function.


Add the fake quant op to the end of activation op, by which the output of activation op will be truncated.
Please check `FakeQuantWithMinMax` for more details.
Please check `FakeQuantWithMinMaxObserver` or other observer for more details.


Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global steps. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - The input of ReLU6Quant.
@@ -1205,21 +1047,14 @@ class ActQuant(_QuantActivation):
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(ActQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
ema=False,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
self.act = activation


def construct(self, x):
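A usage sketch of the simplified ActQuant (this mirrors how the ResNet-50 scripts further down wrap ReLU; custom_config stands for any QuantConfig built as shown earlier):

    from mindspore import nn

    # wraps an activation cell and fake-quantizes its output with the default observers
    act = nn.ActQuant(nn.ReLU())
    # a custom observer pair may be supplied instead:
    # act = nn.ActQuant(nn.ReLU(), quant_config=custom_config)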
@@ -1240,11 +1075,8 @@ class LeakyReLUQuant(_QuantActivation):
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - The input of LeakyReLUQuant.
@@ -1261,30 +1093,19 @@ class LeakyReLUQuant(_QuantActivation):
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(LeakyReLUQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_before = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
self.fake_quant_act_after = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
if issubclass(activation.__class__, nn.LeakyReLU):
self.act = activation
else:
@@ -1309,11 +1130,8 @@ class HSwishQuant(_QuantActivation):
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - The input of HSwishQuant.
@@ -1330,30 +1148,19 @@ class HSwishQuant(_QuantActivation):
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(HSwishQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_before = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
self.fake_quant_act_after = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
if issubclass(activation.__class__, nn.HSwish):
self.act = activation
else:
@@ -1378,11 +1185,8 @@ class HSigmoidQuant(_QuantActivation):
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - The input of HSigmoidQuant.
@@ -1399,30 +1203,19 @@ class HSigmoidQuant(_QuantActivation):
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(HSigmoidQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_before = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
self.fake_quant_act_after = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
if issubclass(activation.__class__, nn.HSigmoid):
self.act = activation
else:
@@ -1446,11 +1239,8 @@ class TensorAddQuant(Cell):


Args:
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - The input of TensorAddQuant.
@@ -1467,21 +1257,14 @@ class TensorAddQuant(Cell):


def __init__(self,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(TensorAddQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
self.add = P.TensorAdd()


def construct(self, x1, x2):
@@ -1498,11 +1281,8 @@ class MulQuant(Cell):


Args:
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer type of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.


Inputs:
- **x** (Tensor) - The input of MulQuant.
@@ -1510,25 +1290,23 @@ class MulQuant(Cell):
Outputs:
Tensor, with the same type and shape as the `x`.


Examples:
>>> mul_quant = nn.MulQuant()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
>>> result = mul_quant(input_x, input_y)
""" """


def __init__(self,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(MulQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
quant_dtype=quant_dtype)
self.mul = P.Mul()


def construct(self, x1, x2):


mindspore/train/quant/quant.py  (+81, -48)

@@ -27,6 +27,7 @@ from ...common import Tensor
from ...common import dtype as mstype
from ...common.api import _executor
from ...nn.layer import quant
from ...compression.common import QuantDtype
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations import _inner_ops as inner
@@ -41,6 +42,46 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
nn.HSwish: quant.HSwishQuant}




def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
quant_delay=(0, 0),
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
per_channel=(False, False),
symmetric=(False, False),
narrow_range=(False, False)
):
r"""
Configures the observer type of weights and data flow with quant params.

Args:
quant_observer (Observer, list or tuple): The observer type to do quantization. The first element represent
weights and second element represent data flow.
Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
element represent weights and second element represent data flow.
Default: (QuantDtype.INT8, QuantDtype.INT8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default: (False, False)
symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
symmetric otherwise base on asymmetric. The first element represent weights and second
element represent data flow. Default: (False, False)
narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
The first element represents weights and the second element represents data flow. Default: (False, False)

Returns:
QuantConfig, contains the observer type of weight and activation.
"""
weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
per_channel=per_channel[0], symmetric=symmetric[0],
narrow_range=narrow_range[0])
act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
per_channel=per_channel[-1], symmetric=symmetric[-1],
narrow_range=narrow_range[-1])
return quant.QuantConfig(weight=weight_observer, activation=act_observer)
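For orientation (not part of the commit): the ResNet-50 scripts below call this helper; a minimal sketch of such a call:

    from mindspore.train.quant import quant

    # per-channel symmetric observers for weights, per-layer asymmetric observers for activations
    config = quant.get_quant_config(per_channel=(True, False), symmetric=(True, False))
    # config.weight / config.activation are partially initialized observer classes that the
    # quant cells (Conv2dBnFoldQuant, DenseQuant, ActQuant, ...) instantiate internally.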


class _AddFakeQuantInput(nn.Cell):
"""
Add FakeQuant OP at input of the network. Only support one input case.
@@ -48,7 +89,8 @@ class _AddFakeQuantInput(nn.Cell):


def __init__(self, network, quant_delay=0):
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input.')
self.network = network


@@ -66,14 +108,14 @@ class _AddFakeQuantAfterSubCell(nn.Cell):
def __init__(self, subcell, **kwargs):
super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
self.subcell = subcell
self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=True,
num_bits=kwargs["num_bits"],
quant_delay=kwargs["quant_delay"],
per_channel=kwargs["per_channel"],
symmetric=kwargs["symmetric"],
narrow_range=kwargs["narrow_range"])
self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
max_init=6,
ema=True,
quant_dtype=kwargs["quant_dtype"],
quant_delay=kwargs["quant_delay"],
per_channel=kwargs["per_channel"],
symmetric=kwargs["symmetric"],
narrow_range=kwargs["narrow_range"])


def construct(self, *data):
output = self.subcell(*data)
@@ -93,8 +135,8 @@ class ConvertToQuantNetwork:
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit")
self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit")
self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
@@ -103,6 +145,11 @@ class ConvertToQuantNetwork:
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
quant.DenseBnAct: self._convert_dense}
self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
quant_dtype=kwargs["quant_dtype"],
per_channel=kwargs["per_channel"],
symmetric=kwargs["symmetric"],
narrow_range=kwargs["narrow_range"])


def _convert_op_name(self, name):
pattern = re.compile(r'([A-Z]{1})')
@@ -149,7 +196,7 @@ class ConvertToQuantNetwork:
for name, prim_op in add_list:
prefix = name
add_quant = _AddFakeQuantAfterSubCell(prim_op,
num_bits=self.act_bits,
quant_dtype=self.act_dtype,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
@@ -180,15 +227,12 @@ class ConvertToQuantNetwork:
group=conv_inner.group,
eps=bn_inner.eps,
momentum=bn_inner.momentum,
quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
fake=True,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range,
has_bias=conv_inner.has_bias,
bias_init=conv_inner.bias_init)
bias_init=conv_inner.bias_init,
freeze_bn=self.freeze_bn,
quant_config=self.quant_config,
quant_dtype=self.weight_dtype,
fake=True)
# change original network BatchNormal OP parameters to quant network
conv_inner.gamma = subcell.batchnorm.gamma
conv_inner.beta = subcell.batchnorm.beta
@@ -209,13 +253,10 @@ class ConvertToQuantNetwork:
group=conv_inner.group,
eps=bn_inner.eps,
momentum=bn_inner.momentum,
quant_delay=self.weight_qdelay,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range,
has_bias=conv_inner.has_bias,
bias_init=conv_inner.bias_init)
bias_init=conv_inner.bias_init,
quant_config=self.quant_config,
quant_dtype=self.weight_dtype)
# change original network BatchNormal OP parameters to quant network
conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
conv_inner.batchnorm.beta = subcell.batchnorm.beta
@@ -234,11 +275,8 @@ class ConvertToQuantNetwork:
dilation=conv_inner.dilation,
group=conv_inner.group,
has_bias=conv_inner.has_bias,
quant_delay=self.weight_qdelay,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
quant_config=self.quant_config,
quant_dtype=self.weight_dtype)
# change original network Conv2D OP parameters to quant network
conv_inner.weight = subcell.conv.weight
if subcell.conv.has_bias:
@@ -249,7 +287,7 @@ class ConvertToQuantNetwork:
elif subcell.after_fake:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
quant_dtype=self.act_dtype,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
@@ -264,11 +302,8 @@ class ConvertToQuantNetwork:
dense_inner = quant.DenseQuant(dense_inner.in_channels,
dense_inner.out_channels,
has_bias=dense_inner.has_bias,
num_bits=self.weight_bits,
quant_delay=self.weight_qdelay,
per_channel=self.weight_channel,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
quant_config=self.quant_config,
quant_dtype=self.weight_dtype)
# change original network Dense OP parameters to quant network
dense_inner.weight = subcell.dense.weight
if subcell.dense.has_bias:
@@ -279,7 +314,7 @@ class ConvertToQuantNetwork:
elif subcell.after_fake:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
quant_dtype=self.act_dtype,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
@@ -291,11 +326,8 @@ class ConvertToQuantNetwork:
if act_class not in _ACTIVATION_MAP:
raise ValueError("Unsupported activation in auto quant: ", act_class)
return _ACTIVATION_MAP[act_class](activation=activation,
num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
quant_config=self.quant_config,
quant_dtype=self.act_dtype)




class ExportToQuantInferNetwork:
@@ -523,7 +555,7 @@ def convert_quant_network(network,
bn_fold=True,
freeze_bn=10000000,
quant_delay=(0, 0),
num_bits=(8, 8),
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
per_channel=(False, False),
symmetric=(False, False),
narrow_range=(False, False)
@@ -537,8 +569,9 @@ def convert_quant_network(network,
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7.
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
num_bits (int, list or tuple): Number of bits to use for quantize weights and activations. The first
element represent weights and second element represent data flow. Default: (8, 8)
quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
element represent weights and second element represent data flow.
Default: (QuantDtype.INT8, QuantDtype.INT8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default: (False, False)
@@ -561,7 +594,7 @@ def convert_quant_network(network,
return value


quant_delay = convert2list("quant delay", quant_delay) quant_delay = convert2list("quant delay", quant_delay)
num_bits = convert2list("num bits", num_bits)
quant_dtype = convert2list("quant dtype", quant_dtype)
per_channel = convert2list("per channel", per_channel) per_channel = convert2list("per channel", per_channel)
symmetric = convert2list("symmetric", symmetric) symmetric = convert2list("symmetric", symmetric)
narrow_range = convert2list("narrow range", narrow_range) narrow_range = convert2list("narrow range", narrow_range)
@@ -573,7 +606,7 @@ def convert_quant_network(network,
quant_delay=quant_delay,
bn_fold=bn_fold,
freeze_bn=freeze_bn,
num_bits=num_bits,
quant_dtype=quant_dtype,
per_channel=per_channel,
symmetric=symmetric,
narrow_range=narrow_range)
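For orientation (not part of the commit): a hedged sketch of calling the updated entry point now that num_bits has been replaced by quant_dtype. Here network stands for an existing nn.Cell defined elsewhere, and the QuantDtype import path is inferred from the relative import at the top of this file.

    from mindspore.train.quant import quant
    from mindspore.compression.common import QuantDtype

    quant_net = quant.convert_quant_network(network,  # an existing nn.Cell, assumed defined elsewhere
                                            bn_fold=True,
                                            per_channel=(True, False),
                                            symmetric=(True, False),
                                            quant_dtype=(QuantDtype.INT8, QuantDtype.INT8))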


model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py  (+15, -16)

@@ -17,12 +17,14 @@ import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
from mindspore.train.quant import quant


_ema_decay = 0.999
_symmetric = True
_fake = True
_per_channel = True
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))




def _weight_variable(shape, factor=0.01):
@@ -89,7 +91,7 @@ class ConvBNReLU(nn.Cell):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
group=groups, fake=_fake, quant_config=_quant_config)
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers)


@@ -124,13 +126,12 @@ class ResidualBlock(nn.Cell):
channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
symmetric=_symmetric,
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
quant_config=_quant_config,
kernel_size=1, stride=1, pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=_quant_config,
kernel_size=1, stride=1,
pad_mode='same', padding=0)


@@ -142,16 +143,15 @@ class ResidualBlock(nn.Cell):


if self.down_sample:
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=_quant_config,
kernel_size=1, stride=stride,
pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
symmetric=False)
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
symmetric=False)
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=\
_quant_config,
kernel_size=1,
stride=stride,
pad_mode='same',
@@ -235,9 +235,8 @@ class ResNet(nn.Cell):


self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)
self.output_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)


def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""


tests/st/quantization/resnet50_quant/resnet_quant_manual.py  (+17, -24)

@@ -13,20 +13,19 @@
# limitations under the License.
# ============================================================================
"""ResNet."""

import numpy as np

import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant

from mindspore import Tensor
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
from mindspore.train.quant import quant


_ema_decay = 0.999
_symmetric = True
_fake = True
_per_channel = True
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))




def _weight_variable(shape, factor=0.01):
@@ -93,7 +92,7 @@ class ConvBNReLU(nn.Cell):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
group=groups, fake=_fake, quant_config=_quant_config)
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers)


@@ -128,14 +127,12 @@ class ResidualBlock(nn.Cell):
channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
symmetric=_symmetric,
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
quant_config=_quant_config,
kernel_size=1, stride=1, pad_mode='same', padding=0),
FakeQuantWithMinMax(
ema=True, ema_decay=_ema_decay, symmetric=False)
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=_quant_config,
kernel_size=1, stride=1,
pad_mode='same', padding=0)


@@ -147,16 +144,15 @@ class ResidualBlock(nn.Cell):


if self.down_sample:
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=_quant_config,
kernel_size=1, stride=stride,
pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
symmetric=False)
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
symmetric=False)
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=\
_quant_config,
kernel_size=1,
stride=stride,
pad_mode='same',
@@ -212,8 +208,7 @@ class ResNet(nn.Cell):
super(ResNet, self).__init__()


if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError(
"the length of layer_num, in_channels, out_channels list must be 4!")
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")


self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
@@ -241,10 +236,8 @@ class ResNet(nn.Cell):


self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)
self.output_fake = nn.FakeQuantWithMinMax(
ema=True, ema_decay=_ema_decay)
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)


# init weights
self._initialize_weights()

