| @@ -181,11 +181,11 @@ class ExportToQuantInferNetwork: | |||
| cell_core = None | |||
| fake_quant_act = None | |||
| activation = None | |||
| if isinstance(subcell, quant.Conv2dBnAct): | |||
| if isinstance(subcell, nn.Conv2dBnAct): | |||
| cell_core = subcell.conv | |||
| activation = subcell.activation | |||
| fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None | |||
| elif isinstance(subcell, quant.DenseBnAct): | |||
| elif isinstance(subcell, nn.DenseBnAct): | |||
| cell_core = subcell.dense | |||
| activation = subcell.activation | |||
| fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None | |||
| @@ -240,9 +240,9 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork): | |||
| subcell = cells[name] | |||
| if subcell == network: | |||
| continue | |||
| if isinstance(subcell, quant.Conv2dBnAct): | |||
| if isinstance(subcell, nn.Conv2dBnAct): | |||
| network, change = self._convert_subcell(network, change, name, subcell) | |||
| elif isinstance(subcell, quant.DenseBnAct): | |||
| elif isinstance(subcell, nn.DenseBnAct): | |||
| network, change = self._convert_subcell(network, change, name, subcell, conv=False) | |||
| elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant, | |||
| quant.Conv2dQuant, quant.DenseQuant)): | |||
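Both export hunks above follow the same pattern: walk the network's cells, match the combined cells by type, and pull out the wrapped core op plus any attached fake-quant activation. A minimal sketch of that dispatch, assuming only the `nn.Conv2dBnAct`/`nn.DenseBnAct` attributes shown in the diff; the standalone `collect_quant_cores` helper is illustrative and not part of the MindSpore source:

```python
from mindspore import nn

def collect_quant_cores(network):
    """Illustrative helper: gather (name, core op, fake-quant activation) for combined cells."""
    results = []
    for name, subcell in network.cells_and_names():
        if subcell is network:
            continue
        if isinstance(subcell, nn.Conv2dBnAct):
            core = subcell.conv          # wrapped Conv2d (or its quantized variant)
        elif isinstance(subcell, nn.DenseBnAct):
            core = subcell.dense         # wrapped Dense (or its quantized variant)
        else:
            continue
        activation = subcell.activation
        # Quant-aware activations carry a fake_quant_act attribute; plain ones do not.
        fake_quant_act = getattr(activation, "fake_quant_act", None)
        results.append((name, core, fake_quant_act))
    return results
```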
| @@ -36,7 +36,7 @@ from .quantizer import Quantizer, OptimizeOption | |||
| __all__ = ["QuantizationAwareTraining", "create_quant_config"] | |||
| def create_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver), | |||
| def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver), | |||
| quant_delay=(0, 0), | |||
| quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), | |||
| per_channel=(False, False), | |||
| @@ -48,7 +48,7 @@ def create_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant | |||
| Args: | |||
| quant_observer (Observer, list or tuple): The observer type to do quantization. The first element represents | |||
| weights and the second element represents data flow. | |||
| Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver) | |||
| Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver) | |||
| quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during | |||
| eval. The first element represents weights and the second element represents data flow. Default: (0, 0) | |||
| quant_dtype (QuantDtype, list or tuple): Datatype to use for quantizing weights and activations. The first | |||
| @@ -210,8 +210,8 @@ class QuantizationAwareTraining(Quantizer): | |||
| self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric") | |||
| self.weight_range = Validator.check_bool(narrow_range[0], "narrow range") | |||
| self.act_range = Validator.check_bool(narrow_range[-1], "narrow range") | |||
| self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, | |||
| quant.DenseBnAct: self._convert_dense} | |||
| self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv, | |||
| nn.DenseBnAct: self._convert_dense} | |||
| self.quant_config = create_quant_config(quant_delay=quant_delay, | |||
| quant_dtype=quant_dtype, | |||
| per_channel=per_channel, | |||
| @@ -257,7 +257,7 @@ class QuantizationAwareTraining(Quantizer): | |||
| subcell = cells[name] | |||
| if subcell == network: | |||
| continue | |||
| elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)): | |||
| elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)): | |||
| prefix = subcell.param_prefix | |||
| new_subcell = self._convert_method_map[type(subcell)](subcell) | |||
| new_subcell.update_parameters_name(prefix + '.') | |||
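For context, these combined cells are what the quantization-aware-training pass keys on: `_convert_method_map` rewrites every `nn.Conv2dBnAct`/`nn.DenseBnAct` it finds. A hedged usage sketch; the `LeNetBlock` network is invented for illustration, and the import path and constructor arguments are assumed from the surrounding code rather than taken verbatim from this PR:

```python
import mindspore.nn as nn
from mindspore.compression.quant import QuantizationAwareTraining  # assumed public path

class LeNetBlock(nn.Cell):
    """Toy network built from the combined cells."""
    def __init__(self):
        super(LeNetBlock, self).__init__()
        self.conv = nn.Conv2dBnAct(1, 6, 5, has_bn=True, activation='relu')
        self.flatten = nn.Flatten()
        self.fc = nn.DenseBnAct(6 * 32 * 32, 10, activation='relu')

    def construct(self, x):
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)

net = LeNetBlock()
# Each Conv2dBnAct/DenseBnAct is replaced by its fake-quantized counterpart.
quantizer = QuantizationAwareTraining(quant_delay=900,
                                      per_channel=[True, False],
                                      symmetric=[True, False])
quant_net = quantizer.quantize(net)
```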
| @@ -17,7 +17,7 @@ Layer. | |||
| The high-level components (Cells) used to construct the neural network. | |||
| """ | |||
| from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math | |||
| from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math, combined | |||
| from .activation import * | |||
| from .normalization import * | |||
| from .container import * | |||
| @@ -29,6 +29,7 @@ from .pooling import * | |||
| from .image import * | |||
| from .quant import * | |||
| from .math import * | |||
| from .combined import * | |||
| __all__ = [] | |||
| __all__.extend(activation.__all__) | |||
| @@ -42,3 +43,4 @@ __all__.extend(pooling.__all__) | |||
| __all__.extend(image.__all__) | |||
| __all__.extend(quant.__all__) | |||
| __all__.extend(math.__all__) | |||
| __all__.extend(combined.__all__) | |||
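Re-exporting `combined.__all__` keeps the moved cells reachable under `mindspore.nn`, which is what the `quant.Conv2dBnAct` → `nn.Conv2dBnAct` renames above rely on. A quick sanity check, assuming the new file lands at `mindspore/nn/layer/combined.py` as the modified `__init__.py` suggests:

```python
from mindspore import nn
from mindspore.nn.layer import combined  # assumed location of the new module

# Both spellings resolve to the same class objects after the re-export.
assert nn.Conv2dBnAct is combined.Conv2dBnAct
assert nn.DenseBnAct is combined.DenseBnAct
```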
| @@ -0,0 +1,215 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Combined cells.""" | |||
| from mindspore import nn | |||
| from mindspore.ops.primitive import Primitive | |||
| from mindspore._checkparam import Validator | |||
| from .normalization import BatchNorm2d, BatchNorm1d | |||
| from .activation import get_activation, LeakyReLU | |||
| from ..cell import Cell | |||
| __all__ = [ | |||
| 'Conv2dBnAct', | |||
| 'DenseBnAct' | |||
| ] | |||
| class Conv2dBnAct(Cell): | |||
| r""" | |||
| A combination of convolution, Batchnorm, and activation layers. | |||
| This part is a more detailed overview of the Conv2d op. | |||
| Args: | |||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||
| kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height | |||
| and width of the 2D convolution window. Single int means the value is for both height and width of | |||
| the kernel. A tuple of 2 ints means the first value is for the height and the other is for the | |||
| width of the kernel. | |||
| stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be | |||
| greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1. | |||
| pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". | |||
| padding (int): Implicit paddings on both sides of the input. Default: 0. | |||
| dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`, | |||
| there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than | |||
| or equal to 1 and lower than any one of the height and width of the input. Default: 1. | |||
| group (int): Splits filter into groups, `in_channels` and `out_channels` must be | |||
| divisible by the number of groups. Default: 1. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. | |||
| It can be a Tensor, a string, an Initializer or a number. When a string is specified, | |||
| values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well | |||
| as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' | |||
| and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of | |||
| Initializer for more details. Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible | |||
| Initializer and string are the same as 'weight_init'. Refer to the values of | |||
| Initializer for more details. Default: 'zeros'. | |||
| has_bn (bool): Specifies whether to use batchnorm. Default: False. | |||
| momentum (float): Momentum for the moving average of batchnorm, must be in [0, 1]. Default: 0.9. | |||
| eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater | |||
| than 0. Default: 1e-5. | |||
| activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following: | |||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||
| alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2. | |||
| after_fake (bool): Determines whether there must be a fake quantization operation after Conv2dBnAct. Default: True. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
| Examples: | |||
| >>> net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu') | |||
| >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | |||
| >>> result = net(input) | |||
| >>> result.shape | |||
| (1, 240, 1024, 640) | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| pad_mode='same', | |||
| padding=0, | |||
| dilation=1, | |||
| group=1, | |||
| has_bias=False, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| has_bn=False, | |||
| momentum=0.9, | |||
| eps=1e-5, | |||
| activation=None, | |||
| alpha=0.2, | |||
| after_fake=True): | |||
| super(Conv2dBnAct, self).__init__() | |||
| self.conv = nn.Conv2d(in_channels, | |||
| out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| pad_mode=pad_mode, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| group=group, | |||
| has_bias=has_bias, | |||
| weight_init=weight_init, | |||
| bias_init=bias_init) | |||
| self.has_bn = Validator.check_bool(has_bn, "has_bn") | |||
| self.has_act = activation is not None | |||
| self.after_fake = Validator.check_bool(after_fake, "after_fake") | |||
| if has_bn: | |||
| self.batchnorm = BatchNorm2d(out_channels, eps, momentum) | |||
| if activation == "leakyrelu": | |||
| self.activation = LeakyReLU(alpha) | |||
| else: | |||
| self.activation = get_activation(activation) if isinstance(activation, str) else activation | |||
| if activation is not None and not isinstance(self.activation, (Cell, Primitive)): | |||
| raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation)) | |||
| def construct(self, x): | |||
| x = self.conv(x) | |||
| if self.has_bn: | |||
| x = self.batchnorm(x) | |||
| if self.has_act: | |||
| x = self.activation(x) | |||
| return x | |||
| class DenseBnAct(Cell): | |||
| r""" | |||
| A combination of Dense, Batchnorm, and activation layers. | |||
| This part is a more detailed overview of the Dense op. | |||
| Args: | |||
| in_channels (int): The number of channels in the input space. | |||
| out_channels (int): The number of channels in the output space. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype | |||
| is the same as the input. The values of str refer to the function `initializer`. Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | |||
| the same as the input. The values of str refer to the function `initializer`. Default: 'zeros'. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||
| has_bn (bool): Specifies whether to use batchnorm. Default: False. | |||
| momentum (float): Momentum for the moving average of batchnorm, must be in [0, 1]. Default: 0.9. | |||
| eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater | |||
| than 0. Default: 1e-5. | |||
| activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following: | |||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||
| alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2. | |||
| after_fake (bool): Determines whether there must be a fake quantization operation after DenseBnAct. Default: True. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, out\_channels)`. | |||
| Examples: | |||
| >>> net = nn.DenseBnAct(3, 4) | |||
| >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) | |||
| >>> result = net(input) | |||
| >>> result.shape | |||
| (2, 4) | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| has_bias=True, | |||
| has_bn=False, | |||
| momentum=0.9, | |||
| eps=1e-5, | |||
| activation=None, | |||
| alpha=0.2, | |||
| after_fake=True): | |||
| super(DenseBnAct, self).__init__() | |||
| self.dense = nn.Dense( | |||
| in_channels, | |||
| out_channels, | |||
| weight_init, | |||
| bias_init, | |||
| has_bias) | |||
| self.has_bn = Validator.check_bool(has_bn, "has_bn") | |||
| self.has_act = activation is not None | |||
| self.after_fake = Validator.check_bool(after_fake, "after_fake") | |||
| if has_bn: | |||
| self.batchnorm = BatchNorm1d(out_channels, eps, momentum) | |||
| if activation == "leakyrelu": | |||
| self.activation = LeakyReLU(alpha) | |||
| else: | |||
| self.activation = get_activation(activation) if isinstance(activation, str) else activation | |||
| if activation is not None and not isinstance(self.activation, (Cell, Primitive)): | |||
| raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation)) | |||
| def construct(self, x): | |||
| x = self.dense(x) | |||
| if self.has_bn: | |||
| x = self.batchnorm(x) | |||
| if self.has_act: | |||
| x = self.activation(x) | |||
| return x | |||
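One detail worth calling out in both constructors: `activation` may be a string name, a `Cell`, or a `Primitive`, and the string `'leakyrelu'` is special-cased so that `alpha` takes effect. A short illustration of the three forms; the shapes and layer sizes below are arbitrary:

```python
import numpy as np
import mindspore
from mindspore import nn, Tensor
from mindspore.ops import operations as P

x = Tensor(np.ones([1, 3, 16, 16]), mindspore.float32)

by_name = nn.Conv2dBnAct(3, 8, 3, activation='leakyrelu', alpha=0.1)  # str -> LeakyReLU(0.1)
by_cell = nn.Conv2dBnAct(3, 8, 3, activation=nn.ReLU())               # Cell instance
by_prim = nn.Conv2dBnAct(3, 8, 3, activation=P.ReLU())                # Primitive instance

for net in (by_name, by_cell, by_prim):
    assert net(x).shape == (1, 8, 16, 16)  # 'same' padding keeps spatial dims
```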
| @@ -17,7 +17,6 @@ | |||
| from functools import partial | |||
| from collections import namedtuple | |||
| import numpy as np | |||
| from mindspore import nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops.primitive import Primitive | |||
| from mindspore.ops import operations as P | |||
| @@ -28,14 +27,12 @@ from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import Validator, Rel, twice | |||
| from mindspore.compression.common import QuantDtype | |||
| import mindspore.context as context | |||
| from .normalization import BatchNorm2d, BatchNorm1d | |||
| from .activation import get_activation, ReLU, LeakyReLU | |||
| from .normalization import BatchNorm2d | |||
| from .activation import get_activation, ReLU | |||
| from ..cell import Cell | |||
| from ...ops.operations import _quant_ops as Q | |||
| __all__ = [ | |||
| 'Conv2dBnAct', | |||
| 'DenseBnAct', | |||
| 'FakeQuantWithMinMaxObserver', | |||
| 'Conv2dBnFoldQuant', | |||
| 'Conv2dBnWithoutFoldQuant', | |||
| @@ -47,192 +44,6 @@ __all__ = [ | |||
| ] | |||
| class Conv2dBnAct(Cell): | |||
| r""" | |||
| A combination of convolution, Batchnorm, activation layer. | |||
| This part is a more detailed overview of Conv2d op. | |||
| Args: | |||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||
| kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height | |||
| and width of the 2D convolution window. Single int means the value is for both height and width of | |||
| the kernel. A tuple of 2 ints means the first value is for the height and the other is for the | |||
| width of the kernel. | |||
| stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be | |||
| greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1. | |||
| pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". | |||
| padding (int): Implicit paddings on both sides of the input. Default: 0. | |||
| dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`, | |||
| there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than | |||
| or equal to 1 and lower than any one of the height and width of the input. Default: 1. | |||
| group (int): Splits filter into groups, `in_ channels` and `out_channels` must be | |||
| divisible by the number of groups. Default: 1. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. | |||
| It can be a Tensor, a string, an Initializer or a number. When a string is specified, | |||
| values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well | |||
| as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' | |||
| and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of | |||
| Initializer for more details. Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible | |||
| Initializer and string are the same as 'weight_init'. Refer to the values of | |||
| Initializer for more details. Default: 'zeros'. | |||
| has_bn (bool): Specifies to used batchnorm or not. Default: False. | |||
| momentum (float): Momentum for moving average for batchnorm, must be [0, 1]. Default:0.9 | |||
| eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater | |||
| than 0. Default: 1e-5. | |||
| activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following: | |||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||
| alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2. | |||
| after_fake(bool): Determine whether there must be a fake quantization operation after Cond2dBnAct. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
| Examples: | |||
| >>> net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU') | |||
| >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | |||
| >>> result = net(input) | |||
| >>> result.shape | |||
| (1, 240, 1024, 640) | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| pad_mode='same', | |||
| padding=0, | |||
| dilation=1, | |||
| group=1, | |||
| has_bias=False, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| has_bn=False, | |||
| momentum=0.9, | |||
| eps=1e-5, | |||
| activation=None, | |||
| alpha=0.2, | |||
| after_fake=True): | |||
| super(Conv2dBnAct, self).__init__() | |||
| self.conv = nn.Conv2d(in_channels, | |||
| out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| pad_mode=pad_mode, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| group=group, | |||
| has_bias=has_bias, | |||
| weight_init=weight_init, | |||
| bias_init=bias_init) | |||
| self.has_bn = Validator.check_bool(has_bn, "has_bn") | |||
| self.has_act = activation is not None | |||
| self.after_fake = Validator.check_bool(after_fake, "after_fake") | |||
| if has_bn: | |||
| self.batchnorm = BatchNorm2d(out_channels, eps, momentum) | |||
| if activation == "leakyrelu": | |||
| self.activation = LeakyReLU(alpha) | |||
| else: | |||
| self.activation = get_activation(activation) if isinstance(activation, str) else activation | |||
| if activation is not None and not isinstance(self.activation, (Cell, Primitive)): | |||
| raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) | |||
| def construct(self, x): | |||
| x = self.conv(x) | |||
| if self.has_bn: | |||
| x = self.batchnorm(x) | |||
| if self.has_act: | |||
| x = self.activation(x) | |||
| return x | |||
| class DenseBnAct(Cell): | |||
| r""" | |||
| A combination of Dense, Batchnorm, and the activation layer. | |||
| This part is a more detailed overview of Dense op. | |||
| Args: | |||
| in_channels (int): The number of channels in the input space. | |||
| out_channels (int): The number of channels in the output space. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype | |||
| is same as input. The values of str refer to the function `initializer`. Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | |||
| same as input. The values of str refer to the function `initializer`. Default: 'zeros'. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||
| activation (Cell): The regularization function applied to the output of the layer, eg. 'ReLU'. Default: None. | |||
| has_bn (bool): Specifies to use batchnorm or not. Default: False. | |||
| momentum (float): Momentum for moving average for batchnorm, must be [0, 1]. Default:0.9 | |||
| eps (float): Term added to the denominator to improve numerical stability for batchnorm, should be greater | |||
| than 0. Default: 1e-5. | |||
| activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following: | |||
| 'Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', | |||
| 'PReLU', 'LeakyReLU', 'h-Swish', and 'h-Sigmoid'. Default: None. | |||
| alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2. | |||
| after_fake(bool): Determine whether there must be a fake quantization operation after DenseBnAct. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, out\_channels)`. | |||
| Examples: | |||
| >>> net = nn.DenseBnAct(3, 4) | |||
| >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) | |||
| >>> result = net(input) | |||
| >>> result.shape | |||
| (2, 4) | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| has_bias=True, | |||
| has_bn=False, | |||
| momentum=0.9, | |||
| eps=1e-5, | |||
| activation=None, | |||
| alpha=0.2, | |||
| after_fake=True): | |||
| super(DenseBnAct, self).__init__() | |||
| self.dense = nn.Dense( | |||
| in_channels, | |||
| out_channels, | |||
| weight_init, | |||
| bias_init, | |||
| has_bias) | |||
| self.has_bn = Validator.check_bool(has_bn, "has_bn") | |||
| self.has_act = activation is not None | |||
| self.after_fake = Validator.check_bool(after_fake, "after_fake") | |||
| if has_bn: | |||
| self.batchnorm = BatchNorm1d(out_channels, eps, momentum) | |||
| if activation == "leakyrelu": | |||
| self.activation = LeakyReLU(alpha) | |||
| self.activation = get_activation(activation) if isinstance(activation, str) else activation | |||
| if activation is not None and not isinstance(self.activation, (Cell, Primitive)): | |||
| raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) | |||
| def construct(self, x): | |||
| x = self.dense(x) | |||
| if self.has_bn: | |||
| x = self.batchnorm(x) | |||
| if self.has_act: | |||
| x = self.activation(x) | |||
| return x | |||
| class BatchNormFoldCell(Cell): | |||
| """ | |||
| Batch normalization folded. | |||