| @@ -17,7 +17,7 @@ Layer. | |||
| The high-level components(Cells) used to construct the neural network. | |||
| """ | |||
| from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU | |||
| from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish | |||
| from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm | |||
| from .container import SequentialCell, CellList | |||
| from .conv import Conv2d, Conv2dTranspose | |||
| @@ -26,8 +26,9 @@ from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradi | |||
| from .embedding import Embedding | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', 'PReLU', 'get_activation', 'LeakyReLU', | |||
| 'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'ELU', | |||
| __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', | |||
| 'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', | |||
| 'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', | |||
| 'SequentialCell', 'CellList', | |||
| 'Conv2d', 'Conv2dTranspose', | |||
| 'LSTM', | |||
| @@ -0,0 +1,703 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Aware quantization.""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import check_int_positive, check_bool, twice | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.nn.layer.conv import _Conv | |||
| from mindspore.nn.layer.activation import get_activation | |||
| __all__ = [ | |||
| 'FakeQuantWithMinMax', | |||
| 'Conv2dBatchNormQuant', | |||
| 'Conv2dQuant', | |||
| 'DenseQuant', | |||
| 'ReLUQuant', | |||
| 'ReLU6Quant', | |||
| 'HSwishQuant', | |||
| 'HSigmoidQuant', | |||
| 'TensorAddQuant', | |||
| ] | |||
| class FakeQuantWithMinMax(Cell): | |||
| r""" | |||
| Quantization aware training op. This op provides a fake-quantization observer on the data, parameterized by min and max. | |||
| Args: | |||
| min_init (int, list): Initial minimum of the quantization range, one value per layer or per channel. Default: -6. | |||
| max_init (int, list): Initial maximum of the quantization range, one value per layer or per channel. Default: 6. | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| ema (bool): Whether to update min and max with an exponential moving average (EMA). Default: False. | |||
| ema_decay (float): Decay factor of the EMA algorithm. Default: 0.999. | |||
| per_channel (bool): Whether quantization is performed per channel (True) or per layer (False). Default: False. | |||
| channel_size (int): Number of channels of the min and max parameters. Default: 1. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of FakeQuantWithMinMax. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `x`. | |||
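| Examples: | |||
| A minimal usage sketch; the input shape and constructor arguments below are illustrative | |||
| assumptions rather than values taken from this change. | |||
| >>> x = Tensor(np.random.rand(1, 3, 4, 4), mstype.float32) | |||
| >>> fake_quant = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=True) | |||
| >>> result = fake_quant(x) | |||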
| """ | |||
| def __init__(self, | |||
| min_init=-6, | |||
| max_init=6, | |||
| num_bits=8, | |||
| ema=False, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| channel_size=1, | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| super(FakeQuantWithMinMax, self).__init__() | |||
| self.min_init = min_init | |||
| self.num_bits = num_bits | |||
| self.max_init = max_init | |||
| self.ema = ema | |||
| self.ema_decay = ema_decay | |||
| self.per_channel = per_channel | |||
| self.channel_size = channel_size | |||
| self.quant_delay = quant_delay | |||
| self.symmetric = symmetric | |||
| self.narrow_range = narrow_range | |||
| if per_channel: | |||
| min_array = np.array([self.min_init for i in range( | |||
| 0, self.channel_size)]).astype(np.float32) | |||
| max_array = np.array([self.max_init for i in range( | |||
| 0, self.channel_size)]).astype(np.float32) | |||
| self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=True) | |||
| self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=ema_decay, | |||
| quant_delay=quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=False) | |||
| else: | |||
| min_array = np.array([min_init]).reshape(1).astype(np.float32) | |||
| max_array = np.array([max_init]).reshape(1).astype(np.float32) | |||
| self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| quant_delay=self.quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=True) | |||
| self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=ema_decay, | |||
| quant_delay=quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=False) | |||
| self.min = Parameter( | |||
| Tensor(min_array), name='quant_min', requires_grad=False) | |||
| self.max = Parameter( | |||
| Tensor(max_array), name='quant_max', requires_grad=False) | |||
| def extend_repr(self): | |||
| s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format( | |||
| self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size, | |||
| self.quant_delay) | |||
| return s | |||
| def construct(self, x): | |||
| if self.training: | |||
| out = self.fake_quant_train(x, self.min, self.max) | |||
| else: | |||
| out = self.fake_quant_infer(x, self.min, self.max) | |||
| return out | |||
| class Conv2dBatchNormQuant(Cell): | |||
| r""" | |||
| 2D convolution layer with folded batch normalization. | |||
| See the Conv2d op for a more detailed overview. | |||
| Args: | |||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||
| kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. | |||
| stride (int): Specifies stride for all spatial dimensions with the same value. | |||
| pad_mode (str): Specifies the padding mode. The optional values are "same", "valid" and "pad". Default: "same". | |||
| padding (int): Implicit paddings on both sides of the input. Default: 0. | |||
| eps (float): Epsilon parameter of the BatchNorm op. Default: 1e-5. | |||
| momentum (float): Momentum parameter of the BatchNorm op. Default: 0.9. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| convolution kernel. Default: None. | |||
| beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| beta vector. Default: None. | |||
| gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| gamma vector. Default: None. | |||
| mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| moving mean vector. Default: None. | |||
| var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| moving variance vector. Default: None. | |||
| group (int): Number of groups to split the filter into; `in_channels` and `out_channels` must be | |||
| divisible by the number of groups. Default: 1. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| freeze_bn (int): Number of global steps after which the BatchNorm statistics are frozen. Default: 100000. | |||
| fake (bool): Whether Conv2dBatchNormQuant adds a FakeQuantWithMinMax op. Default: True. | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| per_channel (bool): Whether the FakeQuantWithMinMax op quantizes per channel. Default: False. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
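| Examples: | |||
| A minimal usage sketch; the channel counts, kernel size and input shape below are illustrative assumptions. | |||
| >>> conv_bn_quant = Conv2dBatchNormQuant(3, 16, kernel_size=3, stride=1, pad_mode='same') | |||
| >>> x = Tensor(np.random.rand(1, 3, 32, 32), mstype.float32) | |||
| >>> result = conv_bn_quant(x) | |||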
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| pad_mode, | |||
| padding=0, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| weight_init=None, | |||
| beta_init=None, | |||
| gamma_init=None, | |||
| mean_init=None, | |||
| var_init=None, | |||
| group=1, | |||
| quant_delay=0, | |||
| freeze_bn=100000, | |||
| fake=True, | |||
| num_bits=8, | |||
| per_channel=False, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| super(Conv2dBatchNormQuant, self).__init__() | |||
| self.stride = stride | |||
| self.conv = P.Conv2D(out_channel=out_channels, | |||
| kernel_size=kernel_size, | |||
| mode=1, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=stride, | |||
| dilation=1, | |||
| group=group) | |||
| self.fake = fake | |||
| self.freeze_bn = freeze_bn | |||
| if isinstance(kernel_size, int): | |||
| kernel_size = (kernel_size, kernel_size) | |||
| if weight_init is None: | |||
| weight_init = initializer( | |||
| 'normal', [out_channels, in_channels // group, *kernel_size]) | |||
| self.weight = Parameter(weight_init, name='weight') | |||
| if gamma_init is None: | |||
| gamma_init = initializer('ones', [out_channels]) | |||
| self.gamma = Parameter(gamma_init, name='gamma') | |||
| if beta_init is None: | |||
| beta_init = initializer('zeros', [out_channels]) | |||
| self.beta = Parameter(beta_init, name='beta') | |||
| if mean_init is None: | |||
| mean_init = initializer('zeros', [out_channels]) | |||
| self.moving_mean = Parameter( | |||
| mean_init, name='moving_mean', requires_grad=False) | |||
| if var_init is None: | |||
| var_init = initializer('ones', [out_channels]) | |||
| self.moving_variance = Parameter( | |||
| var_init, name='moving_variance', requires_grad=False) | |||
| self.step = Parameter(initializer( | |||
| 'normal', [1], dtype=mstype.int32), name='step', requires_grad=False) | |||
| self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| per_channel=per_channel, | |||
| channel_size=out_channels, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps, | |||
| momentum=momentum, | |||
| is_training=True, | |||
| freeze_bn=freeze_bn) | |||
| self.batchnorm_fold_infer = P.BatchNormFold(epsilon=eps, | |||
| momentum=momentum, | |||
| is_training=False, | |||
| freeze_bn=freeze_bn) | |||
| self.correct_mul = P.CorrectionMul() | |||
| self.relu = P.ReLU() | |||
| self.batchnorm_fold2 = P.BatchNormFold2(freeze_bn=freeze_bn) | |||
| self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0) | |||
| self.one = Tensor(1, mstype.int32) | |||
| self.assignadd = P.AssignAdd() | |||
| def extend_repr(self): | |||
| s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn) | |||
| return s | |||
| def construct(self, x): | |||
| if self.training: | |||
| beta = self.beta | |||
| gamma = self.gamma | |||
| gmean = self.moving_mean | |||
| gvar = self.moving_variance | |||
| step = self.step | |||
| out_conv = self.conv(x, self.weight) | |||
| batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train( | |||
| out_conv, gmean, gvar, step) | |||
| # BN fold1 | |||
| weight = self.correct_mul(self.weight, gamma, running_std) | |||
| if self.fake: | |||
| weight = self.fake_quant_weight(weight) | |||
| out = self.conv(x, weight) | |||
| # BN fold2 | |||
| out = self.batchnorm_fold2( | |||
| out, beta, gamma, batch_std, batch_mean, running_std, running_mean, step) | |||
| F.control_depend(out, self.assignadd(self.step, self.one)) | |||
| else: | |||
| step = self.step | |||
| out_conv = self.conv(x, self.weight) | |||
| batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer( | |||
| out_conv, self.moving_mean, self.moving_variance, step) | |||
| weight = self.correct_mul(self.weight, self.gamma, running_std) | |||
| if self.fake: | |||
| weight = self.fake_quant_weight(weight) | |||
| out = self.conv(x, weight) | |||
| out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, | |||
| running_std, running_mean, step) | |||
| return out | |||
| class Conv2dQuant(_Conv): | |||
| r""" | |||
| 2D convolution layer with a fake-quantization op applied to the weight. | |||
| See the Conv2d op for a more detailed overview. | |||
| Args: | |||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||
| kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. | |||
| stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1. | |||
| pad_mode (str): Specifies the padding mode. The optional values are "same", "valid" and "pad". Default: "same". | |||
| padding (int): Implicit paddings on both sides of the input. Default: 0. | |||
| dilation (int): Dilation rate to use for dilated convolution. Default: 1. | |||
| group (int): Number of groups to split the filter into; `in_channels` and `out_channels` must be | |||
| divisible by the number of groups. Default: 1. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. | |||
| Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| per_channel (bool): Whether the FakeQuantWithMinMax op quantizes per channel. Default: False. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
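| Examples: | |||
| A minimal usage sketch; the channel counts, kernel size and input shape below are illustrative assumptions. | |||
| >>> conv_quant = Conv2dQuant(3, 16, kernel_size=3, stride=1, pad_mode='same') | |||
| >>> x = Tensor(np.random.rand(1, 3, 32, 32), mstype.float32) | |||
| >>> result = conv_quant(x) | |||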
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| pad_mode='same', | |||
| padding=0, | |||
| dilation=1, | |||
| group=1, | |||
| has_bias=False, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| quant_delay=0, | |||
| num_bits=8, | |||
| per_channel=False, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| kernel_size = twice(kernel_size) | |||
| super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, | |||
| group, has_bias, weight_init, bias_init) | |||
| self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1, | |||
| pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation, | |||
| group=self.group) | |||
| self.bias_add = P.BiasAdd() | |||
| if pad_mode not in ('valid', 'same', 'pad'): | |||
| raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' | |||
| + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') | |||
| self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| per_channel=per_channel, | |||
| channel_size=out_channels, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| def construct(self, x): | |||
| weight_q = self.fake_quant_weight(self.weight) | |||
| out = self.conv2d(x, weight_q) | |||
| if self.has_bias: | |||
| return self.bias_add(out, self.bias) | |||
| return out | |||
| class DenseQuant(Cell): | |||
| r""" | |||
| Fully connected layer with a fake-quantization op applied to the weight. | |||
| See the Dense op for a more detailed overview. | |||
| Args: | |||
| in_channels (int): The dimension of the input space. | |||
| out_channels (int): The dimension of the output space. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype | |||
| is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | |||
| same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||
| activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None. | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| per_channel (bool): Whether the FakeQuantWithMinMax op quantizes per channel. Default: False. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, out\_channels)`. | |||
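| Examples: | |||
| A minimal usage sketch; the layer sizes and input shape below are illustrative assumptions. | |||
| >>> dense_quant = DenseQuant(16, 10, activation='relu') | |||
| >>> x = Tensor(np.random.rand(2, 16), mstype.float32) | |||
| >>> result = dense_quant(x) | |||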
| """ | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| has_bias=True, | |||
| activation=None, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| per_channel=False, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| super(DenseQuant, self).__init__() | |||
| self.in_channels = check_int_positive(in_channels) | |||
| self.out_channels = check_int_positive(out_channels) | |||
| self.has_bias = check_bool(has_bias) | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ | |||
| weight_init.shape()[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| self.weight = Parameter(initializer( | |||
| weight_init, [out_channels, in_channels]), name="weight") | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer( | |||
| bias_init, [out_channels]), name="bias") | |||
| self.matmul = P.MatMul(transpose_b=True) | |||
| self.bias_add = P.BiasAdd() | |||
| self.activation = get_activation(activation) | |||
| self.activation_flag = self.activation is not None | |||
| self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| per_channel=per_channel, | |||
| channel_size=out_channels, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| def construct(self, x): | |||
| """Use operators to construct to Dense layer.""" | |||
| output = self.fake_quant_weight(self.weight) | |||
| output = self.matmul(x, output) | |||
| if self.has_bias: | |||
| output = self.bias_add(output, self.bias) | |||
| if self.activation_flag: | |||
| return self.activation(output) | |||
| return output | |||
| def extend_repr(self): | |||
| """A pretty print for Dense layer.""" | |||
| str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format( | |||
| self.in_channels, self.out_channels, self.weight, self.has_bias) | |||
| if self.has_bias: | |||
| str_info = str_info + ', bias={}'.format(self.bias) | |||
| if self.activation_flag: | |||
| str_info = str_info + ', activation={}'.format(self.activation) | |||
| return str_info | |||
| class ReLUQuant(Cell): | |||
| r""" | |||
| ReLUQuant activation function. Adds a fake-quantization op after the ReLU op. | |||
| See the ReLU op for a more detailed overview. | |||
| Args: | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of ReLUQuant. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `x`. | |||
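| Examples: | |||
| A minimal usage sketch; the input values below are illustrative assumptions. | |||
| >>> relu_quant = ReLUQuant() | |||
| >>> x = Tensor(np.array([[-1.0, 0.5, 2.0]]), mstype.float32) | |||
| >>> result = relu_quant(x) | |||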
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| super(ReLUQuant, self).__init__() | |||
| self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.relu = P.ReLU() | |||
| def construct(self, x): | |||
| x = self.relu(x) | |||
| x = self.fake_quant_act(x) | |||
| return x | |||
| class ReLU6Quant(Cell): | |||
| r""" | |||
| ReLU6Quant activation function. | |||
| Adds a fake-quantization op after the ReLU6 op. Using this cell is not recommended, since the fake-quantization | |||
| op already clamps the upper range of the activation and ReLU6 performs the same operation. | |||
| See the ReLU6 op for a more detailed overview. | |||
| Args: | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of ReLU6Quant. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `x`. | |||
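| Examples: | |||
| A minimal usage sketch; the input values below are illustrative assumptions. | |||
| >>> relu6_quant = ReLU6Quant() | |||
| >>> x = Tensor(np.array([[-1.0, 3.0, 7.0]]), mstype.float32) | |||
| >>> result = relu6_quant(x) | |||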
| """ | |||
| def __init__(self, num_bits=8, quant_delay=0, symmetric=False, | |||
| narrow_range=False): | |||
| super(ReLU6Quant, self).__init__() | |||
| self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.relu6 = P.ReLU6() | |||
| def construct(self, x): | |||
| x = self.relu6(x) | |||
| x = self.fake_quant_act(x) | |||
| return x | |||
| class HSwishQuant(Cell): | |||
| r""" | |||
| HSwishQuant activation function. Adds fake-quantization ops before and after the HSwish op. | |||
| See the HSwish op for a more detailed overview. | |||
| Args: | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of HSwishQuant. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `x`. | |||
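| Examples: | |||
| A minimal usage sketch; the input values below are illustrative assumptions. | |||
| >>> hswish_quant = HSwishQuant() | |||
| >>> x = Tensor(np.array([[-1.0, 0.0, 2.5]]), mstype.float32) | |||
| >>> result = hswish_quant(x) | |||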
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| super(HSwishQuant, self).__init__() | |||
| self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.act = P.HSwish() | |||
| def construct(self, x): | |||
| x = self.fake_quant_act_before(x) | |||
| x = self.act(x) | |||
| x = self.fake_quant_act_after(x) | |||
| return x | |||
| class HSigmoidQuant(Cell): | |||
| r""" | |||
| HSigmoidQuant activation function. Adds fake-quantization ops before and after the HSigmoid op. | |||
| See the HSigmoid op for a more detailed overview. | |||
| Args: | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of HSigmoidQuant. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `x`. | |||
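| Examples: | |||
| A minimal usage sketch; the input values below are illustrative assumptions. | |||
| >>> hsigmoid_quant = HSigmoidQuant() | |||
| >>> x = Tensor(np.array([[-1.0, 0.0, 2.5]]), mstype.float32) | |||
| >>> result = hsigmoid_quant(x) | |||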
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| super(HSigmoidQuant, self).__init__() | |||
| self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.act = P.HSigmoid() | |||
| def construct(self, x): | |||
| x = self.fake_quant_act_before(x) | |||
| x = self.act(x) | |||
| x = self.fake_quant_act_after(x) | |||
| return x | |||
| class TensorAddQuant(Cell): | |||
| r""" | |||
| Adds a fake-quantization op after the TensorAdd op. | |||
| See the TensorAdd op for a more detailed overview. | |||
| Args: | |||
| num_bits (int): Number of bits used for quantization; 4 and 8 bits are supported. Default: 8. | |||
| quant_delay (int): Number of global steps after which quantization starts during training. Default: 0. | |||
| symmetric (bool): Whether the quantization algorithm is symmetric. Default: False. | |||
| narrow_range (bool): Whether the quantization algorithm uses the narrow range. Default: False. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of TensorAddQuant. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `x`. | |||
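| Examples: | |||
| A minimal usage sketch; the input shapes below are illustrative assumptions. | |||
| >>> add_quant = TensorAddQuant() | |||
| >>> x1 = Tensor(np.ones((2, 3)), mstype.float32) | |||
| >>> x2 = Tensor(np.ones((2, 3)), mstype.float32) | |||
| >>> result = add_quant(x1, x2) | |||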
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| super(TensorAddQuant, self).__init__() | |||
| self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| self.add = P.TensorAdd() | |||
| def construct(self, x1, x2): | |||
| x = self.add(x1, x2) | |||
| x = self.fake_quant_act(x) | |||
| return x | |||
| @@ -234,7 +234,7 @@ class Tanh(Cell): | |||
| class GELU(Cell): | |||
| """ | |||
| r""" | |||
| Gaussian error linear unit activation function. | |||
| Applies GELU function to each element of the input. The input is a Tensor with any valid shape. | |||
| @@ -332,15 +332,74 @@ class PReLU(Cell): | |||
| return v | |||
| class HSwish(Cell): | |||
| r""" | |||
| Hard swish activation function. | |||
| Applies hswish-type activation element-wise. The input is a Tensor with any valid shape. | |||
| Hard swish is defined as: | |||
| .. math:: | |||
| \text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6}, | |||
| where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. | |||
| Inputs: | |||
| - **input_data** (Tensor) - The input of Hswish. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `input_data`. | |||
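| Examples: | |||
| A minimal usage sketch; the input values below are illustrative assumptions. | |||
| >>> hswish = HSwish() | |||
| >>> x = Tensor(np.array([-1.0, 0.0, 1.0, 2.0]).astype(np.float32)) | |||
| >>> result = hswish(x) | |||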
| """ | |||
| def __init__(self): | |||
| super(HSwish, self).__init__() | |||
| self.hswish = P.HSwish() | |||
| def construct(self, x): | |||
| return self.hswish(x) | |||
| class HSigmoid(Cell): | |||
| r""" | |||
| Hard sigmoid activation function. | |||
| Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape. | |||
| Hard sigmoid is defined as: | |||
| .. math:: | |||
| \text{hsigmoid}(x_{i}) = \max(0, \min(1, \frac{2 * x_{i} + 5}{10})), | |||
| where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. | |||
| Inputs: | |||
| - **input_data** (Tensor) - The input of HSigmoid. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `input_data`. | |||
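| Examples: | |||
| A minimal usage sketch; the input values below are illustrative assumptions. | |||
| >>> hsigmoid = HSigmoid() | |||
| >>> x = Tensor(np.array([-1.0, 0.0, 1.0, 2.0]).astype(np.float32)) | |||
| >>> result = hsigmoid(x) | |||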
| """ | |||
| def __init__(self): | |||
| super(HSigmoid, self).__init__() | |||
| self.hsigmoid = P.HSigmoid() | |||
| def construct(self, x): | |||
| return self.hsigmoid(x) | |||
| _activation = { | |||
| 'softmax': Softmax, | |||
| 'logsoftmax': LogSoftmax, | |||
| 'relu': ReLU, | |||
| 'relu6': ReLU6, | |||
| 'tanh': Tanh, | |||
| 'gelu': GELU, | |||
| 'sigmoid': Sigmoid, | |||
| 'prelu': PReLU, | |||
| 'leakyrelu': LeakyReLU | |||
| 'leakyrelu': LeakyReLU, | |||
| 'hswish': HSwish, | |||
| 'hsigmoid': HSigmoid, | |||
| } | |||
| @@ -172,6 +172,28 @@ def get_bprop_relu6(self): | |||
| return bprop | |||
| @bprop_getters.register(P.HSwish) | |||
| def get_bprop_hswish(self): | |||
| """Grad definition for `HSwish` operation.""" | |||
| input_grad = G.HSwishGrad() | |||
| def bprop(x, out, dout): | |||
| dx = input_grad(dout, x) | |||
| return (dx,) | |||
| return bprop | |||
| @bprop_getters.register(P.HSigmoid) | |||
| def get_bprop_hsigmoid(self): | |||
| """Grad definition for `HSigmoid` operation.""" | |||
| input_grad = G.HSigmoidGrad() | |||
| def bprop(x, out, dout): | |||
| dx = input_grad(dout, x) | |||
| return (dx,) | |||
| return bprop | |||
| @bprop_getters.register(P.Elu) | |||
| def get_bprop_elu(self): | |||
| """Grad definition for `Elu` operation.""" | |||
| @@ -0,0 +1,82 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Generate bprop for aware quantization ops""" | |||
| from .. import operations as P | |||
| from .grad_base import bprop_getters | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| @bprop_getters.register(P.FakeQuantWithMinMax) | |||
| def get_bprop_fakequant_with_minmax(self): | |||
| """Generate bprop for FakeQuantWithMinMax""" | |||
| op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) | |||
| def bprop(x, x_min, x_max, out, dout): | |||
| dx = op(dout, x, x_min, x_max) | |||
| return dx, zeros_like(x_min), zeros_like(x_max) | |||
| return bprop | |||
| @bprop_getters.register(P.FakeQuantWithMinMaxPerChannel) | |||
| def get_bprop_fakequant_with_minmax_perchannel(self): | |||
| """Generate bprop for FakeQuantWithMinMaxPerChannel""" | |||
| op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) | |||
| def bprop(x, x_min, x_max, out, dout): | |||
| dx = op(dout, x, x_min, x_max) | |||
| return dx, zeros_like(x_min), zeros_like(x_max) | |||
| return bprop | |||
| @bprop_getters.register(P.BatchNormFold) | |||
| def get_bprop_batchnorm_fold(self): | |||
| """Generate bprop for BatchNormFold""" | |||
| op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn) | |||
| def bprop(x, mean, variance, global_step, out, dout): | |||
| dx = op(dout[0], dout[1], x, out[0], out[1], global_step) | |||
| return dx, zeros_like(mean), zeros_like(variance), zeros_like(global_step) | |||
| return bprop | |||
| @bprop_getters.register(P.CorrectionMul) | |||
| def get_bprop_correction_mul(self): | |||
| """Generate bprop for CorrectionMul""" | |||
| grad = P.CorrectionMulGrad() | |||
| def bprop(x, batch_std, running_std, out, dout): | |||
| dx, d_batch_std = grad(dout, x, batch_std, running_std) | |||
| return dx, d_batch_std, zeros_like(running_std) | |||
| return bprop | |||
| @bprop_getters.register(P.BatchNormFold2) | |||
| def get_bprop_batchnorm_fold2(self): | |||
| """Generate bprop for CorrectionAdd""" | |||
| op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn) | |||
| def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout): | |||
| d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std, | |||
| running_mean, global_step) | |||
| return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \ | |||
| zeros_like(global_step) | |||
| return bprop | |||
| @@ -59,7 +59,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||
| LogSoftmax, | |||
| MaxPool, | |||
| AvgPool, Conv2DBackpropInput, | |||
| MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6, | |||
| MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6, HSwish, HSigmoid, | |||
| ResizeBilinear, Sigmoid, | |||
| SigmoidCrossEntropyWithLogits, | |||
| SmoothL1Loss, Softmax, | |||
| @@ -68,7 +68,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, | |||
| ApplyRMSProp, ApplyCenteredRMSProp) | |||
| from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey | |||
| from . import _quant_ops | |||
| from ._quant_ops import * | |||
| __all__ = [ | |||
| 'TensorAdd', | |||
| @@ -138,6 +139,8 @@ __all__ = [ | |||
| 'ReLU6', | |||
| 'Elu', | |||
| 'Sigmoid', | |||
| 'HSwish', | |||
| 'HSigmoid', | |||
| 'Tanh', | |||
| 'RandomChoiceWithMask', | |||
| 'ResizeBilinear', | |||
| @@ -241,4 +244,5 @@ __all__ = [ | |||
| "ApplyCenteredRMSProp" | |||
| ] | |||
| __all__.extend(_quant_ops.__all__) | |||
| __all__.sort() | |||
| @@ -805,6 +805,38 @@ class SigmoidGrad(PrimitiveWithInfer): | |||
| return out | |||
| class HSigmoidGrad(PrimitiveWithInfer): | |||
| """Gets the gradient of HSigmoid operation.""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output']) | |||
| def infer_shape(self, y_grad_shape, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, y_grad_dtype, x_dtype): | |||
| validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) | |||
| validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32)) | |||
| return x_dtype | |||
| class HSwishGrad(PrimitiveWithInfer): | |||
| """Gets the gradient of HSwish operation.""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output']) | |||
| def infer_shape(self, y_grad_shape, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, y_grad_dtype, x_dtype): | |||
| validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) | |||
| validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32)) | |||
| return x_dtype | |||
| class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): | |||
| """Computes the gradients of `SigmoidCrossEntropyWithLogits`.""" | |||
| @@ -0,0 +1,525 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0(the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http: // www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Operators for quantization.""" | |||
| from ..._checkparam import ParamValidator as validator | |||
| from ..._checkparam import Rel, check_bool, check_int_positive, check_int | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||
| from ...common import dtype as mstype | |||
| __all__ = ["FakeQuantWithMinMax", | |||
| "FakeQuantWithMinMaxGrad", | |||
| "FakeQuantWithMinMaxPerChannel", | |||
| "FakeQuantWithMinMaxPerChannelGrad", | |||
| "BatchNormFold", | |||
| "BatchNormFoldGrad", | |||
| "CorrectionMul", | |||
| "CorrectionMulGrad", | |||
| "BatchNormFold2", | |||
| "BatchNormFold2Grad", | |||
| ] | |||
| class FakeQuantWithMinMax(PrimitiveWithInfer): | |||
| r""" | |||
| Simulate the quantize and dequantize operations in training time. | |||
| Args: | |||
| num_bits (int): Number of bits used for quantization-aware training. Default: 8. | |||
| ema (bool): Whether to update min and max with the EMA algorithm. Default: False. | |||
| ema_decay (float): Decay factor of the EMA algorithm. Default: 0.999. | |||
| quant_delay (int): Quantization delay parameter. Fake quantization is not simulated before this number of | |||
| training steps has elapsed, and is simulated afterwards. Default: 0. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| training (bool): Training the network or not. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) : The float32 input tensor to be fake quantized. | |||
| - **min** (Tensor) : Value of the min range of the input data x. | |||
| - **max** (Tensor) : Value of the max range of the input data x. | |||
| Outputs: | |||
| - Tensor, the fake quantized output tensor with the same shape as `x`. | |||
| Examples: | |||
| >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | |||
| >>> min_tensor = Tensor(np.array([-6]), mstype.float32) | |||
| >>> max_tensor = Tensor(np.array([6]), mstype.float32) | |||
| >>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) | |||
| """ | |||
| support_quant_bit = [4, 7, 8] | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, | |||
| training=True): | |||
| """init FakeQuantWithMinMax OP""" | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError("Attr \'num_bits\' is not support.") | |||
| if ema and not ema_decay: | |||
| raise ValueError( | |||
| "Attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = check_bool(ema) | |||
| self.symmetric = check_bool(symmetric) | |||
| self.narrow_range = check_bool(narrow_range) | |||
| self.training = check_bool(training) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH) | |||
| self.num_bits = check_int_positive(num_bits) | |||
| self.quant_delay = check_int(quant_delay) | |||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | |||
| outputs=['out']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x shape", len(x_shape), 1, Rel.GT) | |||
| validator.check("min shape", min_shape, "max shape", max_shape) | |||
| validator.check_integer("min shape", len(min_shape), 1, Rel.EQ) | |||
| validator.check_integer("max shape", len(min_shape), 1, Rel.EQ) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| validator.check_typename( | |||
| "x type", x_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename("min type", min_type, | |||
| (mstype.float16, mstype.float32)) | |||
| validator.check_typename("max type", max_type, | |||
| (mstype.float16, mstype.float32)) | |||
| return x_type | |||
| class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): | |||
| """Performs grad of FakeQuantWithMinMax operation.""" | |||
| support_quant_bit = [4, 8] | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, quant_delay=0): | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError("Attr \'num_bits\' is not support.") | |||
| self.quant_delay = check_int(quant_delay) | |||
| self.num_bits = check_int_positive(num_bits) | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], | |||
| outputs=['dx']) | |||
| def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): | |||
| validator.check("dout shape", dout_shape, "x shape", x_shape) | |||
| validator.check("min shape", min_shape, "max shape", max_shape) | |||
| validator.check_integer("min shape", len(min_shape), 1, Rel.EQ) | |||
| validator.check_integer("max shape", len(min_shape), 1, Rel.EQ) | |||
| return dout_shape | |||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | |||
| validator.check_typename( | |||
| "dout type", dout_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename( | |||
| "x type", x_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename("min type", min_type, | |||
| (mstype.float16, mstype.float32)) | |||
| validator.check_typename("max type", max_type, | |||
| (mstype.float16, mstype.float32)) | |||
| return dout_type | |||
| class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): | |||
| r""" | |||
| Simulate the quantize and dequantize operations in training time on a per-channel basis. | |||
| Args: | |||
| num_bits (int): Number of bits used for quantization. Default: 8. | |||
| ema (bool): Whether to update the tensor min and tensor max with the EMA algorithm. Default: False. | |||
| ema_decay (float): Decay factor of the EMA algorithm. Default: 0.999. | |||
| quant_delay (int): Quantization delay parameter. The weight data is not fake quantized before this number of | |||
| training steps has elapsed, and is fake quantized afterwards. Default: 0. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| training (bool): Training the network or not. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) : The 4-D float32 input tensor to be fake quantized. | |||
| - **min** (Tensor) : Per-channel minimum values of the range of the input data. | |||
| - **max** (Tensor) : Per-channel maximum values of the range of the input data. | |||
| Outputs: | |||
| - Tensor, has the same type as input. | |||
| Examples: | |||
| >>> input_tensor = Tensor(np.random.rand(3,4,5,5), mstype.float32) | |||
| >>> min_tensor = Tensor(np.array([-6.0, -6.5, -4.0, -5.0]), mstype.float32) | |||
| >>> max_tensor = Tensor(np.array([6.0, 6.5, 4.0, 5.0]), mstype.float32) | |||
| >>> output_tensor = P.FakeQuantWithMinMaxPerChannel(num_bits=8)(input_tensor, min_tensor, max_tensor) | |||
| """ | |||
| support_quant_bit = [4, 8] | |||
| channel_idx = 0 | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, | |||
| training=True): | |||
| """init FakeQuantWithMinMaxPerChannel OP""" | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError("Attr \'num_bits\' is not support.") | |||
| if ema and not ema_decay: | |||
| raise ValueError( | |||
| "Attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = check_bool(ema) | |||
| self.symmetric = check_bool(symmetric) | |||
| self.narrow_range = check_bool(narrow_range) | |||
| self.training = check_bool(training) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH) | |||
| self.num_bits = check_int_positive(num_bits) | |||
| self.quant_delay = check_int(quant_delay) | |||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | |||
| outputs=['out']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x shape", len(x_shape), 1, Rel.GT) | |||
| validator.check_integer( | |||
| "min len", min_shape[0], x_shape[self.channel_idx], Rel.EQ) | |||
| validator.check_integer( | |||
| "max len", max_shape[0], x_shape[self.channel_idx], Rel.EQ) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| validator.check_typename( | |||
| "x type", x_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename("min type", min_type, | |||
| (mstype.float16, mstype.float32)) | |||
| validator.check_typename("max type", max_type, | |||
| (mstype.float16, mstype.float32)) | |||
| return x_type | |||
| class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): | |||
| """Performs grad of FakeQuantWithMinMaxPerChannel operation.""" | |||
| support_quant_bit = [4, 8] | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, quant_delay=0): | |||
| """init FakeQuantWithMinMaxPerChannel Fill""" | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError("Attr \'num_bits\' is not support.") | |||
| self.quant_delay = check_int(quant_delay) | |||
| self.num_bits = check_int_positive(num_bits) | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], | |||
| outputs=['dx']) | |||
| def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): | |||
| validator.check("dout shape", dout_shape, "x shape", x_shape) | |||
| validator.check("min shape", min_shape, "max shape", max_shape) | |||
| return dout_shape | |||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | |||
| validator.check_typename( | |||
| "dout", dout_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename("x", x_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename( | |||
| "min", min_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename( | |||
| "max", max_type, (mstype.float16, mstype.float32)) | |||
| return dout_type | |||
| class BatchNormFold(PrimitiveWithInfer): | |||
| """ | |||
| Batch normalization folded. | |||
| Args: | |||
| momentum (float): Momentum value, which must be in [0, 1]. Default: 0.1. | |||
| epsilon (float): A small float added to the variance to avoid dividing by zero; typically 1e-12 for | |||
| float32 and 1e-3 otherwise. Default: 1e-12. | |||
| is_training (bool): In training mode set True, else set False. Default: True. | |||
| freeze_bn (int): Delay in steps at which computation switches from regular batch | |||
| norm to frozen mean and std. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C)`. | |||
| - **mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **variance** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **global_step** (Tensor) - Tensor to record current global step. | |||
| Outputs: | |||
| Tuple of 4 Tensor, the normalized input and the updated parameters. | |||
| - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
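| Examples: | |||
| A minimal call sketch; the shapes and values below are illustrative assumptions. | |||
| >>> x = Tensor(np.random.rand(4, 3), mstype.float32) | |||
| >>> mean = Tensor(np.zeros(3), mstype.float32) | |||
| >>> variance = Tensor(np.ones(3), mstype.float32) | |||
| >>> global_step = Tensor(np.array([0]), mstype.int32) | |||
| >>> batch_mean, batch_std, running_mean, running_std = P.BatchNormFold()(x, mean, variance, global_step) | |||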
| """ | |||
| channel = 1 | |||
| @prim_attr_register | |||
| def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0): | |||
| """init batch norm fold layer""" | |||
| self.momentum = validator.check_number_range( | |||
| 'momentum', momentum, 0, 1, Rel.INC_BOTH) | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon) | |||
| self.is_training = check_bool(is_training) | |||
| self.freeze_bn = check_int(freeze_bn) | |||
| self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'], | |||
| outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std']) | |||
| def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): | |||
| validator.check("mean shape", mean_shape, | |||
| "gamma_shape", variance_shape) | |||
| validator.check("mean_shape size", | |||
| mean_shape[0], "input channel", x_shape[self.channel]) | |||
| validator.check_integer("global_step shape", | |||
| len(global_step_shape), 1, Rel.EQ) | |||
| return mean_shape, mean_shape, mean_shape, mean_shape | |||
| def infer_dtype(self, x_type, mean_type, variance_type, global_step_type): | |||
| validator.check("input type", x_type, "mean type", mean_type) | |||
| validator.check("input type", x_type, "variance type", variance_type) | |||
| validator.check_typename("input type", x_type, | |||
| (mstype.float16, mstype.float32)) | |||
| validator.check_typename( | |||
| "global_step type", global_step_type, (mstype.int32,)) | |||
| return x_type, x_type, x_type, x_type | |||
| class BatchNormFoldGrad(PrimitiveWithInfer): | |||
| """Performs grad of BatchNormFold operation.""" | |||
| channel = 1 | |||
| @prim_attr_register | |||
| def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0): | |||
| """init BatchNormGrad layer""" | |||
| self.is_training = check_bool(is_training) | |||
| self.freeze_bn = check_int(freeze_bn) | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon) | |||
| self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'], | |||
| outputs=['dx']) | |||
| def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape, | |||
| global_step_shape): | |||
| validator.check("d_batch_mean shape", d_batch_mean_shape, | |||
| "d_batch_std shape", d_batch_std_shape) | |||
| validator.check("d_batch_mean shape", d_batch_mean_shape, | |||
| "batch_mean shape", batch_mean_shape) | |||
| validator.check("d_batch_mean shape", d_batch_mean_shape, | |||
| "batch_std shape", batch_std_shape) | |||
| validator.check( | |||
| "x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[self.channel]) | |||
| validator.check_integer("global_step shape", | |||
| len(global_step_shape), 1, Rel.EQ) | |||
| return x_shape | |||
| def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type, | |||
| global_step_type): | |||
| validator.check("input type", x_type, | |||
| "d_batch_mean type", d_batch_mean_type) | |||
| validator.check("input type", x_type, | |||
| "d_batch_std type", d_batch_std_type) | |||
| validator.check("input type", x_type, | |||
| "batch_mean type", batch_mean_type) | |||
| validator.check("input type", x_type, "batch_std type", batch_std_type) | |||
| validator.check_typename("input type", x_type, | |||
| (mstype.float16, mstype.float32)) | |||
| validator.check_typename( | |||
| "global_step type", global_step_type, (mstype.int32,)) | |||
| return x_type | |||
| class CorrectionMul(PrimitiveWithInfer): | |||
| """ | |||
| Scale the weights with a correction factor to the long term statistics | |||
| prior to quantization. This ensures that there is no jitter in the quantized weights | |||
| due to batch to batch variation. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C)`. | |||
| - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| Outputs: | |||
| - **out** (Tensor) - Tensor has the same shape as x. | |||
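| Examples: | |||
| A minimal call sketch; the shapes below are illustrative assumptions. | |||
| >>> weight = Tensor(np.random.rand(16, 3), mstype.float32) | |||
| >>> batch_std = Tensor(np.ones(16), mstype.float32) | |||
| >>> running_std = Tensor(np.ones(16), mstype.float32) | |||
| >>> out = P.CorrectionMul()(weight, batch_std, running_std) | |||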
| """ | |||
| channel = 0 | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init correction mul layer""" | |||
| self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'], | |||
| outputs=['out']) | |||
| def infer_shape(self, x_shape, batch_std_shape, running_std_shape): | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "running_std shape", running_std_shape) | |||
| validator.check( | |||
| "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel]) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, batch_std_type, running_std_type): | |||
| validator.check("batch_std type", batch_std_type, | |||
| "running_std type", running_std_type) | |||
| validator.check("batch_std_type", batch_std_type, "x_type", x_type) | |||
| validator.check_typename( | |||
| "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) | |||
| return x_type | |||
| class CorrectionMulGrad(PrimitiveWithInfer): | |||
| """Performs grad of CorrectionMul operation.""" | |||
| channel = 0 | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init correction mul layer""" | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], | |||
| outputs=['dx', 'd_gamma']) | |||
| def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape): | |||
| validator.check("dout shape", dout_shape, "x_shape x", x_shape) | |||
| validator.check( | |||
| "gamma size", gamma_shape[0], "dout channel size", dout_shape[self.channel]) | |||
| validator.check( | |||
| "running_std size", running_std_shape[0], "dout channel size", dout_shape[self.channel]) | |||
| return x_shape, gamma_shape | |||
| def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): | |||
| validator.check("x type", x_type, "dout type", dout_type) | |||
| validator.check("gamma type", gamma_type, "dout type", dout_type) | |||
| validator.check("running_std type", running_std_type, | |||
| "dout type", dout_type) | |||
| validator.check_typename( | |||
| "dout type", dout_type, (mstype.float16, mstype.float32)) | |||
| return x_type, x_type | |||
| class BatchNormFold2(PrimitiveWithInfer): | |||
| """ | |||
| Scale the bias with a correction factor to the long term statistics | |||
| prior to quantization. This ensures that there is no jitter in the quantized bias | |||
| due to batch to batch variation. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C)`. | |||
| - **beta** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **gamma** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_std** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **global_step** (Tensor) - Tensor to record current global step. | |||
| Outputs: | |||
| - **y** (Tensor) - Tensor has the same shape as x. | |||
| """ | |||
| channel = 1 | |||
| @prim_attr_register | |||
| def __init__(self, freeze_bn=0): | |||
| """init conv2d fold layer""" | |||
| self.freeze_bn = check_int(freeze_bn) | |||
| self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', | |||
| 'running_std', 'running_mean', 'global_step'], | |||
| outputs=['y']) | |||
| def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape, | |||
| running_mean_shape, global_step_shape): | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "running_std shape", running_std_shape) | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "batch_mean shape", batch_mean_shape) | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "beta shape", beta_shape) | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "running_mean shape", running_mean_shape) | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "batch_mean shape", gamma_shape) | |||
| validator.check( | |||
| "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel]) | |||
| validator.check_integer("global_step shape", | |||
| len(global_step_shape), 1, Rel.EQ) | |||
| return x_shape | |||
| def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type, | |||
| running_mean_type, global_step_type): | |||
| validator.check("batch_std type", batch_std_type, | |||
| "running_std type", running_std_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "batch_mean type", batch_mean_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "beta type", beta_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "running_mean type", running_mean_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "gamma type", gamma_type) | |||
| validator.check("x_type", x_type, "batch_std type", batch_std_type) | |||
| validator.check_typename( | |||
| "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename( | |||
| "global_step type", global_step_type, (mstype.int32,)) | |||
| return x_type | |||
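| # Editor's note (hedged sketch, not part of this diff): the helper below shows a | |||
| # shape-level use of the inputs documented above, with hypothetical values. It | |||
| # assumes a backend that registers this primitive's kernel (e.g. Ascend/GPU) and | |||
| # direct pynative-style invocation; the helper name itself is made up. | |||
| def _batch_norm_fold2_shape_example(): | |||
|     """Shape-level sketch of BatchNormFold2's documented inputs.""" | |||
|     import numpy as np | |||
|     from mindspore import Tensor | |||
|     batch, channel = 2, 4 | |||
|     x = Tensor(np.random.randn(batch, channel).astype(np.float32))  # (N, C) | |||
|     per_channel = Tensor(np.ones(channel, np.float32))              # (C,) | |||
|     global_step = Tensor(np.array([0], np.int32))                   # 1-D int32 | |||
|     bn_fold2 = BatchNormFold2(freeze_bn=10) | |||
|     # beta, gamma and the four statistics all share shape (C,), so one placeholder | |||
|     # tensor is reused here purely to illustrate the call signature. | |||
|     return bn_fold2(x, per_channel, per_channel, per_channel, per_channel, | |||
|                     per_channel, per_channel, global_step)          # same shape as x | |||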
| class BatchNormFold2Grad(PrimitiveWithInfer): | |||
| """Performs grad of CorrectionAddGrad operation.""" | |||
| channel = 1 | |||
| @prim_attr_register | |||
| def __init__(self, freeze_bn=0): | |||
| """init MulFold layer""" | |||
| self.freeze_bn = freeze_bn | |||
| self.init_prim_io_names(inputs=['dout', 'x', 'gamma', | |||
| 'batch_std', 'batch_mean', | |||
| 'running_std', 'running_mean', 'global_step'], | |||
| outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx']) | |||
| def infer_shape(self, dout_shape, x_shape, gamma_shape, | |||
| batch_std_shape, batch_mean_shape, | |||
| running_std_shape, running_mean_shape, global_step_shape): | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "batch_mean shape", batch_mean_shape) | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "running_std shape", running_std_shape) | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "running_mean shape", running_mean_shape) | |||
| validator.check("batch_std shape", batch_std_shape, | |||
| "gamma shape", gamma_shape) | |||
| validator.check( | |||
| "batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel]) | |||
| validator.check_integer("global_step shape", | |||
| len(global_step_shape), 1, Rel.EQ) | |||
| return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape | |||
| def infer_dtype(self, dout_type, x_type, gamma_type, | |||
| batch_std_type, batch_mean_type, | |||
| running_std_type, running_mean_type, global_step_type): | |||
| validator.check("batch_std type", batch_std_type, | |||
| "batch_mean type", batch_mean_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "gamma type", gamma_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "running_std type", running_std_type) | |||
| validator.check("batch_std type", batch_std_type, | |||
| "running_mean type", running_mean_type) | |||
| validator.check("batch_std_type", batch_std_type, | |||
| "dout type", dout_type) | |||
| validator.check_typename( | |||
| "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) | |||
| validator.check_typename( | |||
| "global_step type", global_step_type, (mstype.int32,)) | |||
| return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type | |||
| @@ -207,7 +207,7 @@ class ReLU6(PrimitiveWithInfer): | |||
| class Elu(PrimitiveWithInfer): | |||
| """ | |||
| r""" | |||
| Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise. | |||
| The data type of input tensor should be float. | |||
| @@ -242,6 +242,40 @@ class Elu(PrimitiveWithInfer): | |||
| return input_x | |||
| class HSwish(PrimitiveWithInfer): | |||
| r""" | |||
| Hard swish activation function. | |||
| Applies hswish-type activation element-wise. The input is a Tensor with any valid shape. | |||
| Hard swish is defined as: | |||
| .. math:: | |||
| \text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6}, | |||
| where :math:`x_{i}` is an element of the input Tensor. | |||
| Inputs: | |||
| - **input_data** (Tensor) - The input of Hswish. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `input_data`. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| def infer_shape(self, xshape): | |||
| return xshape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_subclass("x_dtype", x_dtype, mstype.tensor) | |||
| validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) | |||
| return x_dtype | |||
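| # Editor's note (hedged sketch, not part of this diff): a NumPy reference of the | |||
| # formula documented above, hswish(x) = x * ReLU6(x + 3) / 6, handy for checking | |||
| # the primitive's output element-wise. The helper name is hypothetical. | |||
| def _hswish_reference(x): | |||
|     """Element-wise hard swish matching the docstring formula (NumPy only).""" | |||
|     import numpy as np | |||
|     x = np.asarray(x, dtype=np.float32) | |||
|     relu6 = np.clip(x + 3.0, 0.0, 6.0)   # ReLU6(x + 3) | |||
|     return x * relu6 / 6.0 | |||
| # e.g. _hswish_reference([-4.0, -1.0, 0.0, 1.0, 4.0]) | |||
| # -> approximately [-0.0, -0.333, 0.0, 0.667, 4.0] | |||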
| class Sigmoid(PrimitiveWithInfer): | |||
| r""" | |||
| Sigmoid activation function. | |||
| @@ -258,6 +292,7 @@ class Sigmoid(PrimitiveWithInfer): | |||
| Outputs: | |||
| Tensor, with the same type and shape as the input_x. | |||
| """ | |||
| @prim_attr_register | |||
| @@ -273,6 +308,40 @@ class Sigmoid(PrimitiveWithInfer): | |||
| return input_x | |||
| class HSigmoid(PrimitiveWithInfer): | |||
| r""" | |||
| Hard sigmoid activation function. | |||
| Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape. | |||
| Hard sigmoid is defined as: | |||
| .. math:: | |||
| \text{hsigmoid}(x_{i}) = \max(0, \min(1, \frac{2 * x_{i} + 5}{10})), | |||
| where :math:`x_{i}` is an element of the input Tensor. | |||
| Inputs: | |||
| - **input_data** (Tensor) - The input of HSigmoid. | |||
| Outputs: | |||
| Tensor, with the same type and shape as the `input_data`. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_subclass("x_dtype", x_dtype, mstype.tensor) | |||
| validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) | |||
| return x_dtype | |||
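| # Editor's note (hedged sketch, not part of this diff): a NumPy reference of the | |||
| # formula documented above, hsigmoid(x) = max(0, min(1, (2x + 5) / 10)). The | |||
| # helper name is hypothetical. | |||
| def _hsigmoid_reference(x): | |||
|     """Element-wise hard sigmoid matching the docstring formula (NumPy only).""" | |||
|     import numpy as np | |||
|     x = np.asarray(x, dtype=np.float32) | |||
|     return np.clip((2.0 * x + 5.0) / 10.0, 0.0, 1.0) | |||
| # e.g. _hsigmoid_reference([-3.0, 0.0, 3.0]) -> approximately [0.0, 0.5, 1.0] | |||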
| class Tanh(PrimitiveWithInfer): | |||
| r""" | |||
| Tanh activation function. | |||
| @@ -27,11 +27,6 @@ def test_dense_none(): | |||
| nn.Dense(3, 2, None, None) | |||
| def test_dense_invalid_activation(): | |||
| with pytest.raises(KeyError): | |||
| nn.Dense(3, 2, activation='relu6') | |||
| @non_graph_engine | |||
| def test_dense_str_activation(): | |||
| dense = nn.Dense(1, 1, activation='relu') | |||
| @@ -51,11 +51,6 @@ def test_activation_empty(): | |||
| assert nn.get_activation('') is None | |||
| def test_activation_invalid(): | |||
| with pytest.raises(KeyError): | |||
| nn.get_activation('relu6') | |||
| # test softmax | |||
| def test_softmax_axis(): | |||
| layer = nn.Softmax(1) | |||
| @@ -68,11 +68,6 @@ def test_dense_none(): | |||
| nn.Dense(3, 2, None, None) | |||
| def test_dense_invalid_activation(): | |||
| with pytest.raises(KeyError): | |||
| nn.Dense(3, 2, activation='relu6') | |||
| def test_dense_str_activation(): | |||
| dense = nn.Dense(1, 1, activation='relu') | |||
| assert isinstance(dense.activation, nn.ReLU) | |||
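| # Editor's note (hedged sketch, not part of this diff): since the KeyError tests | |||
| # for 'relu6' are removed above, the apparent intent is that get_activation now | |||
| # resolves that string; a follow-up check along these lines could confirm it. The | |||
| # test name is hypothetical. | |||
| def test_dense_relu6_str_activation(): | |||
|     dense = nn.Dense(1, 1, activation='relu6') | |||
|     assert isinstance(dense.activation, nn.ReLU6) | |||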