diff --git a/mindspore/compression/export/quant_export.py b/mindspore/compression/export/quant_export.py
index 766bc9eee4..2aa8284cbe 100644
--- a/mindspore/compression/export/quant_export.py
+++ b/mindspore/compression/export/quant_export.py
@@ -181,11 +181,11 @@ class ExportToQuantInferNetwork:
             cell_core = None
             fake_quant_act = None
             activation = None
-            if isinstance(subcell, quant.Conv2dBnAct):
+            if isinstance(subcell, nn.Conv2dBnAct):
                 cell_core = subcell.conv
                 activation = subcell.activation
                 fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
-            elif isinstance(subcell, quant.DenseBnAct):
+            elif isinstance(subcell, nn.DenseBnAct):
                 cell_core = subcell.dense
                 activation = subcell.activation
                 fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
@@ -240,9 +240,9 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
             subcell = cells[name]
             if subcell == network:
                 continue
-            if isinstance(subcell, quant.Conv2dBnAct):
+            if isinstance(subcell, nn.Conv2dBnAct):
                 network, change = self._convert_subcell(network, change, name, subcell)
-            elif isinstance(subcell, quant.DenseBnAct):
+            elif isinstance(subcell, nn.DenseBnAct):
                 network, change = self._convert_subcell(network, change, name, subcell, conv=False)
             elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
                                       quant.Conv2dQuant, quant.DenseQuant)):
diff --git a/mindspore/compression/quant/qat.py b/mindspore/compression/quant/qat.py
index f7fe7cfa69..0a6156a010 100644
--- a/mindspore/compression/quant/qat.py
+++ b/mindspore/compression/quant/qat.py
@@ -36,7 +36,7 @@ from .quantizer import Quantizer, OptimizeOption
 
 __all__ = ["QuantizationAwareTraining", "create_quant_config"]
 
 
-def create_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
+def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
                         quant_delay=(0, 0),
                         quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                         per_channel=(False, False),
@@ -48,7 +48,7 @@ def create_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant
     Args:
         quant_observer (Observer, list or tuple): The oberser type to do quantization. The first element represent
             weights and second element represent data flow.
-            Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
+            Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver)
         quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
             eval. The first element represent weights and second element represent data flow. Default: (0, 0)
         quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
@@ -210,8 +210,8 @@ class QuantizationAwareTraining(Quantizer):
         self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
         self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
         self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
-        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
-                                    quant.DenseBnAct: self._convert_dense}
+        self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
+                                    nn.DenseBnAct: self._convert_dense}
         self.quant_config = create_quant_config(quant_delay=quant_delay,
                                                 quant_dtype=quant_dtype,
                                                 per_channel=per_channel,
@@ -257,7 +257,7 @@ class QuantizationAwareTraining(Quantizer):
             subcell = cells[name]
             if subcell == network:
                 continue
-            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
+            elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                 prefix = subcell.param_prefix
                 new_subcell = self._convert_method_map[type(subcell)](subcell)
                 new_subcell.update_parameters_name(prefix + '.')
diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py
index 9999142a42..4aecf85d97 100644
--- a/mindspore/nn/layer/__init__.py
+++ b/mindspore/nn/layer/__init__.py
@@ -17,7 +17,7 @@ Layer.
 
 The high-level components(Cells) used to construct the neural network.
 """
-from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math
+from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math, combined
 from .activation import *
 from .normalization import *
 from .container import *
@@ -29,6 +29,7 @@ from .pooling import *
 from .image import *
 from .quant import *
 from .math import *
+from .combined import *
 
 __all__ = []
 __all__.extend(activation.__all__)
@@ -42,3 +43,4 @@ __all__.extend(pooling.__all__)
 __all__.extend(image.__all__)
 __all__.extend(quant.__all__)
 __all__.extend(math.__all__)
+__all__.extend(combined.__all__)
diff --git a/mindspore/nn/layer/combined.py b/mindspore/nn/layer/combined.py
new file mode 100644
index 0000000000..8acdd0f8e4
--- /dev/null
+++ b/mindspore/nn/layer/combined.py
@@ -0,0 +1,215 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Combined cells."""
+
+from mindspore import nn
+from mindspore.ops.primitive import Primitive
+from mindspore._checkparam import Validator
+from .normalization import BatchNorm2d, BatchNorm1d
+from .activation import get_activation, LeakyReLU
+from ..cell import Cell
+
+
+__all__ = [
+    'Conv2dBnAct',
+    'DenseBnAct'
+]
+
+
+class Conv2dBnAct(Cell):
+    r"""
+    A combination of convolution, Batchnorm, and activation layers.
+
+    This part is a more detailed overview of the Conv2d operation.
+
+    Args:
+        in_channels (int): The number of input channels :math:`C_{in}`.
+        out_channels (int): The number of output channels :math:`C_{out}`.
+        kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height
+            and width of the 2D convolution window. Single int means the value is for both height and width of
+            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
+            width of the kernel.
+        stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be
+            greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1.
+        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
+        padding (int): Implicit paddings on both sides of the input. Default: 0.
+        dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
+            there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than
+            or equal to 1 and lower than any one of the height and width of the input. Default: 1.
+        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
+            divisible by the number of groups. Default: 1.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
+            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
+            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
+            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
+            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
+            Initializer for more details. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
+            Initializer and string are the same as 'weight_init'. Refer to the values of
+            Initializer for more details. Default: 'zeros'.
+        has_bn (bool): Specifies whether to use batchnorm. Default: False.
+        momentum (float): Momentum for the moving average of batchnorm, must be in the range [0, 1]. Default: 0.9.
+        eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater
+            than 0. Default: 1e-5.
+        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as follows:
+            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
+            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
+        alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
+        after_fake (bool): Determines whether there must be a fake quantization operation after Conv2dBnAct.
+            Default: True.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu')
+        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
+        >>> result = net(input)
+        >>> result.shape
+        (1, 240, 1024, 640)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 pad_mode='same',
+                 padding=0,
+                 dilation=1,
+                 group=1,
+                 has_bias=False,
+                 weight_init='normal',
+                 bias_init='zeros',
+                 has_bn=False,
+                 momentum=0.9,
+                 eps=1e-5,
+                 activation=None,
+                 alpha=0.2,
+                 after_fake=True):
+        super(Conv2dBnAct, self).__init__()
+
+        self.conv = nn.Conv2d(in_channels,
+                              out_channels,
+                              kernel_size=kernel_size,
+                              stride=stride,
+                              pad_mode=pad_mode,
+                              padding=padding,
+                              dilation=dilation,
+                              group=group,
+                              has_bias=has_bias,
+                              weight_init=weight_init,
+                              bias_init=bias_init)
+        self.has_bn = Validator.check_bool(has_bn, "has_bn")
+        self.has_act = activation is not None
+        self.after_fake = Validator.check_bool(after_fake, "after_fake")
+        if has_bn:
+            self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
+        if activation == "leakyrelu":
+            self.activation = LeakyReLU(alpha)
+        else:
+            self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
+
+    def construct(self, x):
+        x = self.conv(x)
+        if self.has_bn:
+            x = self.batchnorm(x)
+        if self.has_act:
+            x = self.activation(x)
+        return x
+
+
+class DenseBnAct(Cell):
+    r"""
+    A combination of Dense, Batchnorm, and activation layers.
+
+    This part is a more detailed overview of the Dense operation.
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as input. The values of str refer to the function `initializer`. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as input. The values of str refer to the function `initializer`. Default: 'zeros'.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+        has_bn (bool): Specifies whether to use batchnorm. Default: False.
+        momentum (float): Momentum for the moving average of batchnorm, must be in the range [0, 1]. Default: 0.9.
+        eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater
+            than 0. Default: 1e-5.
+        activation (Union[str, Cell, Primitive]): The activation function applied to the output of the layer.
+            The optional values are as follows:
+            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
+            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
+        alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
+        after_fake (bool): Determines whether there must be a fake quantization operation after DenseBnAct.
+            Default: True.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor of shape :math:`(N, out\_channels)`.
+
+    Examples:
+        >>> net = nn.DenseBnAct(3, 4)
+        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
+        >>> result = net(input)
+        >>> result.shape
+        (2, 4)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 weight_init='normal',
+                 bias_init='zeros',
+                 has_bias=True,
+                 has_bn=False,
+                 momentum=0.9,
+                 eps=1e-5,
+                 activation=None,
+                 alpha=0.2,
+                 after_fake=True):
+        super(DenseBnAct, self).__init__()
+        self.dense = nn.Dense(
+            in_channels,
+            out_channels,
+            weight_init,
+            bias_init,
+            has_bias)
+        self.has_bn = Validator.check_bool(has_bn, "has_bn")
+        self.has_act = activation is not None
+        self.after_fake = Validator.check_bool(after_fake, "after_fake")
+        if has_bn:
+            self.batchnorm = BatchNorm1d(out_channels, eps, momentum)
+        if activation == "leakyrelu":
+            self.activation = LeakyReLU(alpha)
+        else:
+            self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
+
+    def construct(self, x):
+        x = self.dense(x)
+        if self.has_bn:
+            x = self.batchnorm(x)
+        if self.has_act:
+            x = self.activation(x)
+        return x
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 7d7090b425..773c8cc84a 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -17,7 +17,6 @@ from functools import partial
 from collections import namedtuple
 
 import numpy as np
-from mindspore import nn
 import mindspore.common.dtype as mstype
 from mindspore.ops.primitive import Primitive
 from mindspore.ops import operations as P
@@ -28,14 +27,12 @@ from mindspore.common.tensor import Tensor
 from mindspore._checkparam import Validator, Rel, twice
 from mindspore.compression.common import QuantDtype
 import mindspore.context as context
-from .normalization import BatchNorm2d, BatchNorm1d
-from .activation import get_activation, ReLU, LeakyReLU
+from .normalization import BatchNorm2d
+from .activation import get_activation, ReLU
 from ..cell import Cell
 from ...ops.operations import _quant_ops as Q
 
 __all__ = [
-    'Conv2dBnAct',
-    'DenseBnAct',
     'FakeQuantWithMinMaxObserver',
     'Conv2dBnFoldQuant',
     'Conv2dBnWithoutFoldQuant',
@@ -47,192 +44,6 @@ __all__ = [
 ]
 
 
-class Conv2dBnAct(Cell):
-    r"""
-    A combination of convolution, Batchnorm, activation layer.
-
-    This part is a more detailed overview of Conv2d op.
-
-    Args:
-        in_channels (int): The number of input channel :math:`C_{in}`.
-        out_channels (int): The number of output channel :math:`C_{out}`.
-        kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height
-            and width of the 2D convolution window. Single int means the value is for both height and width of
-            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
-            width of the kernel.
-        stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be
-            greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1.
-        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
-        padding (int): Implicit paddings on both sides of the input. Default: 0.
-        dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
-            there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than
-            or equal to 1 and lower than any one of the height and width of the input. Default: 1.
-        group (int): Splits filter into groups, `in_ channels` and `out_channels` must be
-            divisible by the number of groups. Default: 1.
-        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
-        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
-            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
-            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
-            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
-            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
-            Initializer for more details. Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
-            Initializer and string are the same as 'weight_init'. Refer to the values of
-            Initializer for more details. Default: 'zeros'.
-        has_bn (bool): Specifies to used batchnorm or not. Default: False.
-        momentum (float): Momentum for moving average for batchnorm, must be [0, 1]. Default:0.9
-        eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater
-            than 0. Default: 1e-5.
-        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
-            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
-            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
-        alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
-        after_fake(bool): Determine whether there must be a fake quantization operation after Cond2dBnAct.
-
-    Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
-
-    Outputs:
-        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
-
-    Examples:
-        >>> net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')
-        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
-        >>> result = net(input)
-        >>> result.shape
-        (1, 240, 1024, 640)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 pad_mode='same',
-                 padding=0,
-                 dilation=1,
-                 group=1,
-                 has_bias=False,
-                 weight_init='normal',
-                 bias_init='zeros',
-                 has_bn=False,
-                 momentum=0.9,
-                 eps=1e-5,
-                 activation=None,
-                 alpha=0.2,
-                 after_fake=True):
-        super(Conv2dBnAct, self).__init__()
-
-        self.conv = nn.Conv2d(in_channels,
-                              out_channels,
-                              kernel_size=kernel_size,
-                              stride=stride,
-                              pad_mode=pad_mode,
-                              padding=padding,
-                              dilation=dilation,
-                              group=group,
-                              has_bias=has_bias,
-                              weight_init=weight_init,
-                              bias_init=bias_init)
-        self.has_bn = Validator.check_bool(has_bn, "has_bn")
-        self.has_act = activation is not None
-        self.after_fake = Validator.check_bool(after_fake, "after_fake")
-        if has_bn:
-            self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
-        if activation == "leakyrelu":
-            self.activation = LeakyReLU(alpha)
-        else:
-            self.activation = get_activation(activation) if isinstance(activation, str) else activation
-        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
-            raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
-
-    def construct(self, x):
-        x = self.conv(x)
-        if self.has_bn:
-            x = self.batchnorm(x)
-        if self.has_act:
-            x = self.activation(x)
-        return x
-
-
-class DenseBnAct(Cell):
-    r"""
-    A combination of Dense, Batchnorm, and the activation layer.
-
-    This part is a more detailed overview of Dense op.
-
-    Args:
-        in_channels (int): The number of channels in the input space.
-        out_channels (int): The number of channels in the output space.
-        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
-            is same as input. The values of str refer to the function `initializer`. Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
-            same as input. The values of str refer to the function `initializer`. Default: 'zeros'.
-        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
-        activation (Cell): The regularization function applied to the output of the layer, eg. 'ReLU'. Default: None.
-        has_bn (bool): Specifies to use batchnorm or not. Default: False.
-        momentum (float): Momentum for moving average for batchnorm, must be [0, 1]. Default:0.9
-        eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater
-            than 0. Default: 1e-5.
-        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
-            'Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
-            'PReLU', 'LeakyReLU', 'h-Swish', and 'h-Sigmoid'. Default: None.
-        alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
-        after_fake(bool): Determine whether there must be a fake quantization operation after DenseBnAct.
-
-    Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
-
-    Outputs:
-        Tensor of shape :math:`(N, out\_channels)`.
-
-    Examples:
-        >>> net = nn.DenseBnAct(3, 4)
-        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
-        >>> result = net(input)
-        >>> result.shape
-        (2, 4)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 weight_init='normal',
-                 bias_init='zeros',
-                 has_bias=True,
-                 has_bn=False,
-                 momentum=0.9,
-                 eps=1e-5,
-                 activation=None,
-                 alpha=0.2,
-                 after_fake=True):
-        super(DenseBnAct, self).__init__()
-        self.dense = nn.Dense(
-            in_channels,
-            out_channels,
-            weight_init,
-            bias_init,
-            has_bias)
-        self.has_bn = Validator.check_bool(has_bn, "has_bn")
-        self.has_act = activation is not None
-        self.after_fake = Validator.check_bool(after_fake, "after_fake")
-        if has_bn:
-            self.batchnorm = BatchNorm1d(out_channels, eps, momentum)
-        if activation == "leakyrelu":
-            self.activation = LeakyReLU(alpha)
-        self.activation = get_activation(activation) if isinstance(activation, str) else activation
-        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
-            raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
-
-    def construct(self, x):
-        x = self.dense(x)
-        if self.has_bn:
-            x = self.batchnorm(x)
-        if self.has_act:
-            x = self.activation(x)
-        return x
-
-
 class BatchNormFoldCell(Cell):
     """
     Batch normalization folded.
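
After this patch, the combined cells are re-exported at the `mindspore.nn` level, so user code addresses them as `nn.Conv2dBnAct` and `nn.DenseBnAct` rather than through `mindspore.nn.layer.quant`. A minimal usage sketch, assuming a MindSpore build with this patch applied (the input shapes below are illustrative, not from the patch):

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# Combined cells now live in mindspore/nn/layer/combined.py and are
# re-exported from mindspore.nn, replacing the old quant.* spelling.
conv_block = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu')
dense_block = nn.DenseBnAct(3, 4, activation='relu')

x = Tensor(np.ones([1, 120, 32, 32]), mindspore.float32)
print(conv_block(x).shape)   # (1, 240, 32, 32); pad_mode='same' keeps H and W

y = Tensor(np.ones([2, 3]), mindspore.float32)
print(dense_block(y).shape)  # (2, 4)
```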
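The qat.py change also means `QuantizationAwareTraining` now matches the combined cells by their `nn` types when rewriting a network. A hedged sketch of that flow under stated assumptions: `TinyNet` and the quantizer arguments are hypothetical examples, not part of the patch.

```python
import mindspore.nn as nn
from mindspore.compression.quant import QuantizationAwareTraining

class TinyNet(nn.Cell):
    """Toy network built from the relocated combined cells; expects (N, 3, 32, 32) inputs."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2dBnAct(3, 8, 3, has_bn=True, activation='relu')
        self.flatten = nn.Flatten()
        self.fc = nn.DenseBnAct(8 * 32 * 32, 10)

    def construct(self, x):
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)

# _convert_method_map is now keyed on nn.Conv2dBnAct / nn.DenseBnAct, so
# quantize() rewrites these cells into their fake-quantized counterparts.
quantizer = QuantizationAwareTraining(quant_delay=(900, 900), per_channel=(True, False))
quant_net = quantizer.quantize(TinyNet())
```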