|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
"""Quantization aware training.""" |
|
|
|
|
|
|
|
from functools import partial |
|
|
|
from collections import namedtuple |
|
|
|
import numpy as np |
|
|
|
from mindspore import nn |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
@@ -34,7 +35,7 @@ from ...ops.operations import _quant_ops as Q |
|
|
|
__all__ = [ |
|
|
|
'Conv2dBnAct', |
|
|
|
'DenseBnAct', |
|
|
|
'FakeQuantWithMinMax', |
|
|
|
'FakeQuantWithMinMaxObserver', |
|
|
|
'Conv2dBnFoldQuant', |
|
|
|
'Conv2dBnWithoutFoldQuant', |
|
|
|
'Conv2dQuant', |
|
|
|
@@ -422,14 +423,14 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
"""Initialize FakeQuantWithMinMax layer""" |
|
|
|
"""Initialize FakeQuantWithMinMaxObserver""" |
|
|
|
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel, |
|
|
|
symmetric=symmetric, narrow_range=narrow_range, |
|
|
|
num_channels=num_channels) |
|
|
|
Validator.check_type("min_init", min_init, [int, float]) |
|
|
|
Validator.check_type("max_init", max_init, [int, float]) |
|
|
|
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) |
|
|
|
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) |
|
|
|
Validator.check_non_negative_int(quant_delay, 'quant_delay') |
|
|
|
self.min_init = min_init |
|
|
|
self.max_init = max_init |
|
|
|
self.quant_dtype = quant_dtype |
|
|
|
@@ -498,119 +499,9 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class FakeQuantWithMinMax(Cell): |
|
|
|
r""" |
|
|
|
Quantization aware op. This OP provides the fake quantization observer function on data with min and max. |
|
|
|
|
|
|
|
Args: |
|
|
|
min_init (int, float): The initialized min value. Default: -6. |
|
|
|
max_init (int, float): The initialized max value. Default: 6. |
|
|
|
ema (bool): Whether the Exponential Moving Average algorithm is used to update min and max. Default: False.
|
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. |
|
|
|
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. |
|
|
|
channel_axis (int): Quantization by channel axis. Default: 1. |
|
|
|
num_channels (int): Declares the channel size of min and max. Default: 1.
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The input of FakeQuantWithMinMax. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, with the same type and shape as the `x`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> fake_quant = FakeQuantWithMinMax() |
|
|
|
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) |
|
|
|
>>> result = fake_quant(input_x) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=False, |
|
|
|
ema_decay=0.999, |
|
|
|
per_channel=False, |
|
|
|
channel_axis=1, |
|
|
|
num_channels=1, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
"""Initialize FakeQuantWithMinMax layer""" |
|
|
|
super(FakeQuantWithMinMax, self).__init__() |
|
|
|
Validator.check_type("min_init", min_init, [int, float]) |
|
|
|
Validator.check_type("max_init", max_init, [int, float]) |
|
|
|
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) |
|
|
|
Validator.check_non_negative_int(quant_delay, 'quant_delay') |
|
|
|
self.min_init = min_init |
|
|
|
self.max_init = max_init |
|
|
|
self.num_bits = num_bits |
|
|
|
self.ema = ema |
|
|
|
self.ema_decay = ema_decay |
|
|
|
self.per_channel = per_channel |
|
|
|
self.num_channels = num_channels |
|
|
|
self.channel_axis = channel_axis |
|
|
|
self.quant_delay = quant_delay |
|
|
|
self.symmetric = symmetric |
|
|
|
self.narrow_range = narrow_range |
|
|
|
self.is_ascend = context.get_context('device_target') == "Ascend" |
|
|
|
|
|
|
|
# init tensor min and max for fake quant op |
|
|
|
if self.per_channel: |
|
|
|
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32) |
|
|
|
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32) |
|
|
|
else: |
|
|
|
min_array = np.array([self.min_init]).astype(np.float32) |
|
|
|
max_array = np.array([self.max_init]).astype(np.float32) |
|
|
|
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) |
|
|
|
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) |
|
|
|
|
|
|
|
# init fake quant relative op |
|
|
|
if self.per_channel: |
|
|
|
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) |
|
|
|
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) |
|
|
|
else: |
|
|
|
quant_fun = Q.FakeQuantPerLayer |
|
|
|
ema_fun = Q.MinMaxUpdatePerLayer |
|
|
|
|
|
|
|
self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay) |
|
|
|
if self.is_ascend: |
|
|
|
self.fake_quant_train = quant_fun(num_bits=self.num_bits, |
|
|
|
symmetric=self.symmetric, |
|
|
|
narrow_range=self.narrow_range, |
|
|
|
quant_delay=self.quant_delay) |
|
|
|
self.fake_quant_infer = self.fake_quant_train |
|
|
|
else: |
|
|
|
quant_fun = partial(quant_fun, |
|
|
|
ema=self.ema, |
|
|
|
ema_decay=ema_decay, |
|
|
|
num_bits=self.num_bits, |
|
|
|
symmetric=self.symmetric, |
|
|
|
narrow_range=self.narrow_range, |
|
|
|
quant_delay=self.quant_delay) |
|
|
|
self.fake_quant_train = quant_fun(training=True) |
|
|
|
self.fake_quant_infer = quant_fun(training=False) |
|
|
|
|
|
|
|
def extend_repr(self): |
|
|
|
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ |
|
|
|
'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range, |
|
|
|
self.ema, self.ema_decay, self.per_channel, |
|
|
|
self.channel_axis, self.num_channels, self.quant_delay, |
|
|
|
self.min_init, self.max_init) |
|
|
|
return s |
|
|
|
QuantConfig = namedtuple("QuantConfig", ['weight', 'activation']) |
|
|
|
|
|
|
|
def construct(self, x): |
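# Training: refresh the running min/max with the EMA update op and write them back
# into the `minq`/`maxq` parameters before fake-quantizing; inference reuses the
# stored min/max unchanged.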
|
|
|
if self.training: |
|
|
|
min_up, max_up = self.ema_update(x, self.minq, self.maxq) |
|
|
|
P.Assign()(self.minq, min_up) |
|
|
|
P.Assign()(self.maxq, max_up) |
|
|
|
out = self.fake_quant_train(x, self.minq, self.maxq) |
|
|
|
else: |
|
|
|
out = self.fake_quant_infer(x, self.minq, self.maxq) |
|
|
|
return out |
|
|
|
quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver) |
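# Illustration only (not part of the original module): a custom QuantConfig can be
# assembled with functools.partial to hand differently configured observers to the
# *Quant cells defined below. The keyword values here are assumptions chosen for
# demonstration, for example:
#
#     custom_config = QuantConfig(
#         weight=partial(FakeQuantWithMinMaxObserver, per_channel=True, symmetric=True),
#         activation=partial(FakeQuantWithMinMaxObserver, ema=True, ema_decay=0.999))
#     conv = Conv2dBnFoldQuant(3, 16, kernel_size=3, quant_config=custom_config)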
|
|
|
|
|
|
|
|
|
|
|
class Conv2dBnFoldQuant(Cell): |
|
|
|
@@ -641,12 +532,9 @@ class Conv2dBnFoldQuant(Cell): |
|
|
|
mean vector. Default: 'zeros'. |
|
|
|
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the |
|
|
|
variance vector. Default: 'ones'. |
|
|
|
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMax op. Default: True. |
|
|
|
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): The quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): The Quantization delay parameters according to the global step. Default: 0. |
|
|
|
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
freeze_bn (int): The global step at which the quantization freezes the BatchNorm op. Default: 100000.
|
|
|
|
|
|
|
Inputs: |
|
|
|
@@ -680,11 +568,8 @@ class Conv2dBnFoldQuant(Cell): |
|
|
|
mean_init='zeros', |
|
|
|
var_init='ones', |
|
|
|
fake=True, |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0, |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8, |
|
|
|
freeze_bn=100000): |
|
|
|
"""Initialize Conv2dBnFoldQuant layer""" |
|
|
|
super(Conv2dBnFoldQuant, self).__init__() |
|
|
|
@@ -699,13 +584,10 @@ class Conv2dBnFoldQuant(Cell): |
|
|
|
self.eps = eps |
|
|
|
self.momentum = momentum |
|
|
|
self.has_bias = has_bias |
|
|
|
self.quant_delay = quant_delay |
|
|
|
self.freeze_bn = freeze_bn |
|
|
|
self.fake = fake |
|
|
|
self.num_bits = num_bits |
|
|
|
self.per_channel = per_channel |
|
|
|
self.symmetric = symmetric |
|
|
|
self.narrow_range = narrow_range |
|
|
|
self.quant_config = quant_config |
|
|
|
self.quant_dtype = quant_dtype |
|
|
|
self.is_gpu = context.get_context('device_target') == "GPU" |
|
|
|
|
|
|
|
# initialize convolution op and Parameter |
|
|
|
@@ -745,16 +627,12 @@ class Conv2dBnFoldQuant(Cell): |
|
|
|
requires_grad=False) |
|
|
|
|
|
|
|
# initialize fake ops |
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
self.fake_quant_weight = quant_config.weight(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=False, |
|
|
|
per_channel=per_channel, |
|
|
|
channel_axis=channel_axis, |
|
|
|
num_channels=out_channels, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) |
|
|
|
self.correct_mul = Q.CorrectionMul(channel_axis) |
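# BatchNormFoldCell supplies the per-channel batch/running statistics used to fold
# the BatchNorm into the convolution, and CorrectionMul rescales the weight by the
# matching per-channel factor before the weight observer is applied.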
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
@@ -777,7 +655,7 @@ class Conv2dBnFoldQuant(Cell): |
|
|
|
self.pad_mode, self.padding, self.dilation, |
|
|
|
self.group, |
|
|
|
self.fake, self.freeze_bn, self.momentum, |
|
|
|
self.quant_delay) |
|
|
|
self.fake_quant_weight.quant_delay) |
|
|
|
return s |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
@@ -836,11 +714,8 @@ class Conv2dBnWithoutFoldQuant(Cell): |
|
|
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. |
|
|
|
Default: 'normal'. |
|
|
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. |
|
|
|
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): The quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. |
|
|
|
@@ -868,11 +743,8 @@ class Conv2dBnWithoutFoldQuant(Cell): |
|
|
|
momentum=0.997, |
|
|
|
weight_init='normal', |
|
|
|
bias_init='zeros', |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(Conv2dBnWithoutFoldQuant, self).__init__() |
|
|
|
if isinstance(kernel_size, int): |
|
|
|
self.kernel_size = (kernel_size, kernel_size) |
|
|
|
@@ -886,7 +758,6 @@ class Conv2dBnWithoutFoldQuant(Cell): |
|
|
|
self.pad_mode = pad_mode |
|
|
|
self.padding = padding |
|
|
|
self.group = group |
|
|
|
self.quant_delay = quant_delay |
|
|
|
|
|
|
|
self.bias_add = P.BiasAdd() |
|
|
|
if Validator.check_bool(has_bias): |
|
|
|
@@ -917,16 +788,12 @@ class Conv2dBnWithoutFoldQuant(Cell): |
|
|
|
weight_shape = [out_channels, in_channels // group, *self.kernel_size] |
|
|
|
channel_axis = 0 |
|
|
|
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') |
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
self.fake_quant_weight = quant_config.weight(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=False, |
|
|
|
per_channel=per_channel, |
|
|
|
channel_axis=channel_axis, |
|
|
|
num_channels=out_channels, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
quant_dtype=quant_dtype) |
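# Unlike Conv2dBnFoldQuant, the BatchNorm below is kept as a separate op and its
# statistics are never folded into the convolution weights.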
|
|
|
self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
@@ -942,7 +809,7 @@ class Conv2dBnWithoutFoldQuant(Cell): |
|
|
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \ |
|
|
|
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, |
|
|
|
self.pad_mode, self.padding, self.dilation, self.group, |
|
|
|
self.has_bias, self.quant_delay) |
|
|
|
self.has_bias, self.fake_quant_weight.quant_delay) |
|
|
|
return s |
|
|
|
|
|
|
|
|
|
|
|
@@ -966,11 +833,8 @@ class Conv2dQuant(Cell): |
|
|
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. |
|
|
|
Default: 'normal'. |
|
|
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. |
|
|
|
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): The quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. |
|
|
|
@@ -996,11 +860,8 @@ class Conv2dQuant(Cell): |
|
|
|
has_bias=False, |
|
|
|
weight_init='normal', |
|
|
|
bias_init='zeros', |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(Conv2dQuant, self).__init__() |
|
|
|
if isinstance(kernel_size, int): |
|
|
|
self.kernel_size = (kernel_size, kernel_size) |
|
|
|
@@ -1014,7 +875,6 @@ class Conv2dQuant(Cell): |
|
|
|
self.pad_mode = pad_mode |
|
|
|
self.padding = padding |
|
|
|
self.group = group |
|
|
|
self.quant_delay = quant_delay |
|
|
|
|
|
|
|
weight_shape = [out_channels, in_channels // group, *self.kernel_size] |
|
|
|
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') |
|
|
|
@@ -1033,16 +893,12 @@ class Conv2dQuant(Cell): |
|
|
|
stride=self.stride, |
|
|
|
dilation=self.dilation, |
|
|
|
group=self.group) |
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
self.fake_quant_weight = quant_config.weight(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=False, |
|
|
|
per_channel=per_channel, |
|
|
|
channel_axis=0, |
|
|
|
num_channels=out_channels, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
|
|
|
|
def construct(self, x): |
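# Fake-quantize the weight first; the convolution (and bias add, if enabled) then
# runs on the simulated-quantized weight.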
|
|
|
weight = self.fake_quant_weight(self.weight) |
|
|
|
@@ -1056,7 +912,7 @@ class Conv2dQuant(Cell): |
|
|
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \ |
|
|
|
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, |
|
|
|
self.pad_mode, self.padding, self.dilation, self.group, |
|
|
|
self.has_bias, self.quant_delay) |
|
|
|
self.has_bias, self.fake_quant_weight.quant_delay) |
|
|
|
return s |
|
|
|
|
|
|
|
|
|
|
|
@@ -1075,11 +931,8 @@ class DenseQuant(Cell): |
|
|
|
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. |
|
|
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. |
|
|
|
activation (str): The activation function applied to the output of the layer, e.g. 'relu'. Default: None.
|
|
|
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): The quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. |
|
|
|
@@ -1093,19 +946,15 @@ class DenseQuant(Cell): |
|
|
|
>>> result = dense_quant(input_x) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
in_channels, |
|
|
|
out_channels, |
|
|
|
weight_init='normal', |
|
|
|
bias_init='zeros', |
|
|
|
has_bias=True, |
|
|
|
activation=None, |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
def __init__(self, |
|
|
|
in_channels, |
|
|
|
out_channels, |
|
|
|
weight_init='normal', |
|
|
|
bias_init='zeros', |
|
|
|
has_bias=True, |
|
|
|
activation=None, |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(DenseQuant, self).__init__() |
|
|
|
self.in_channels = Validator.check_positive_int(in_channels) |
|
|
|
self.out_channels = Validator.check_positive_int(out_channels) |
|
|
|
@@ -1132,16 +981,12 @@ class DenseQuant(Cell): |
|
|
|
|
|
|
|
self.activation = get_activation(activation) |
|
|
|
self.activation_flag = self.activation is not None |
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
self.fake_quant_weight = quant_config.weight(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=False, |
|
|
|
per_channel=per_channel, |
|
|
|
channel_axis=0, |
|
|
|
num_channels=out_channels, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
"""Use operators to construct the Dense layer.""" |
|
|
|
@@ -1179,16 +1024,13 @@ class ActQuant(_QuantActivation): |
|
|
|
Quantization aware training activation function. |
|
|
|
|
|
|
|
Adds the fake quant op to the end of the activation op, by which the output of the activation op will be truncated.
|
|
|
Please check `FakeQuantWithMinMax` for more details. |
|
|
|
Please check `FakeQuantWithMinMaxObserver` or other observers for more details.
|
|
|
|
|
|
|
Args: |
|
|
|
activation (Cell): Activation cell class. |
|
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. |
|
|
|
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): The quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The input of ActQuant.
|
|
|
@@ -1205,21 +1047,14 @@ class ActQuant(_QuantActivation): |
|
|
|
def __init__(self, |
|
|
|
activation, |
|
|
|
ema_decay=0.999, |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(ActQuant, self).__init__() |
|
|
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=0, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=False, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
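# The observer above is applied to the output of the wrapped activation in
# construct, truncating it to the learned range.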
|
|
|
self.act = activation |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
@@ -1240,11 +1075,8 @@ class LeakyReLUQuant(_QuantActivation): |
|
|
|
Args: |
|
|
|
activation (Cell): Activation cell class. |
|
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. |
|
|
|
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): The quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The input of LeakyReLUQuant. |
|
|
|
@@ -1261,30 +1093,19 @@ class LeakyReLUQuant(_QuantActivation): |
|
|
|
def __init__(self, |
|
|
|
activation, |
|
|
|
ema_decay=0.999, |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(LeakyReLUQuant, self).__init__() |
|
|
|
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act_before = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
self.fake_quant_act_after = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
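# One observer fake-quantizes the input before the LeakyReLU and the other
# fake-quantizes its output; both track their ranges with EMA.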
|
|
|
if issubclass(activation.__class__, nn.LeakyReLU): |
|
|
|
self.act = activation |
|
|
|
else: |
|
|
|
@@ -1309,11 +1130,8 @@ class HSwishQuant(_QuantActivation): |
|
|
|
Args: |
|
|
|
activation (Cell): Activation cell class. |
|
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. |
|
|
|
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The input of HSwishQuant. |
|
|
|
@@ -1330,30 +1148,19 @@ class HSwishQuant(_QuantActivation): |
|
|
|
def __init__(self, |
|
|
|
activation, |
|
|
|
ema_decay=0.999, |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(HSwishQuant, self).__init__() |
|
|
|
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act_before = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
self.fake_quant_act_after = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
if issubclass(activation.__class__, nn.HSwish): |
|
|
|
self.act = activation |
|
|
|
else: |
|
|
|
@@ -1378,11 +1185,8 @@ class HSigmoidQuant(_QuantActivation): |
|
|
|
Args: |
|
|
|
activation (Cell): Activation cell class. |
|
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. |
|
|
|
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The input of HSigmoidQuant. |
|
|
|
@@ -1399,30 +1203,19 @@ class HSigmoidQuant(_QuantActivation): |
|
|
|
def __init__(self, |
|
|
|
activation, |
|
|
|
ema_decay=0.999, |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(HSigmoidQuant, self).__init__() |
|
|
|
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act_before = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
self.fake_quant_act_after = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
if issubclass(activation.__class__, nn.HSigmoid): |
|
|
|
self.act = activation |
|
|
|
else: |
|
|
|
@@ -1446,11 +1239,8 @@ class TensorAddQuant(Cell): |
|
|
|
|
|
|
|
Args: |
|
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. |
|
|
|
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): The quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The input of TensorAddQuant. |
|
|
|
@@ -1467,21 +1257,14 @@ class TensorAddQuant(Cell): |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
ema_decay=0.999, |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(TensorAddQuant, self).__init__() |
|
|
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
self.add = P.TensorAdd() |
|
|
|
|
|
|
|
def construct(self, x1, x2): |
|
|
|
@@ -1498,11 +1281,8 @@ class MulQuant(Cell): |
|
|
|
|
|
|
|
Args: |
|
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. |
|
|
|
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. |
|
|
|
num_bits (int): The bit number of quantization, supporting 4 and 8 bits. Default: 8.
|
|
|
symmetric (bool): The quantization algorithm is symmetric or not. Default: False. |
|
|
|
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. |
|
|
|
quant_delay (int): Quantization delay parameters according to the global step. Default: 0. |
|
|
|
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The input of MulQuant. |
|
|
|
@@ -1510,25 +1290,23 @@ class MulQuant(Cell): |
|
|
|
Outputs: |
|
|
|
Tensor, with the same type and shape as the `x`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> mul_quant = nn.MulQuant() |
|
|
|
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) |
|
|
|
>>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32) |
|
|
|
>>> result = mul_quant(input_x, input_y) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
ema_decay=0.999, |
|
|
|
per_channel=False, |
|
|
|
num_bits=8, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
quant_delay=0): |
|
|
|
quant_config=quant_config_default, |
|
|
|
quant_dtype=QuantDtype.INT8): |
|
|
|
super(MulQuant, self).__init__() |
|
|
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
per_channel=per_channel, |
|
|
|
num_bits=num_bits, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range, |
|
|
|
quant_delay=quant_delay) |
|
|
|
self.fake_quant_act = quant_config.activation(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
quant_dtype=quant_dtype) |
|
|
|
self.mul = P.Mul() |
|
|
|
|
|
|
|
def construct(self, x1, x2): |
|
|
|
|