Merge pull request !2168 from chenzhongming/combinedtags/v0.5.0-beta
| @@ -1,182 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Use combination of Conv, Dense, Relu, Batchnorm.""" | |||||
| from .normalization import BatchNorm2d | |||||
| from .activation import get_activation | |||||
| from ..cell import Cell | |||||
| from . import conv, basic | |||||
| from ..._checkparam import ParamValidator as validator | |||||
| __all__ = ['Conv2d', 'Dense'] | |||||
| class Conv2d(Cell): | |||||
| r""" | |||||
| A combination of convolution, Batchnorm, activation layer. | |||||
| For a more Detailed overview of Conv2d op. | |||||
| Args: | |||||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||||
| kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height | |||||
| and width of the 2D convolution window. Single int means the value if for both height and width of | |||||
| the kernel. A tuple of 2 ints means the first value is for the height and the other is for the | |||||
| width of the kernel. | |||||
| stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be | |||||
| greater or equal to 1 but bounded by the height and width of the input. Default: 1. | |||||
| pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". | |||||
| padding (int): Implicit paddings on both sides of the input. Default: 0. | |||||
| dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`, | |||||
| there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater | |||||
| or equal to 1 and bounded by the height and width of the input. Default: 1. | |||||
| group (int): Split filter into groups, `in_ channels` and `out_channels` should be | |||||
| divisible by the number of groups. Default: 1. | |||||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. | |||||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. | |||||
| It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, | |||||
| values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well | |||||
| as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' | |||||
| and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of | |||||
| Initializer for more details. Default: 'normal'. | |||||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible | |||||
| Initializer and string are the same as 'weight_init'. Refer to the values of | |||||
| Initializer for more details. Default: 'zeros'. | |||||
| batchnorm (bool): Specifies to used batchnorm or not. Default: None. | |||||
| activation (string): Specifies activation type. The optional values are as following: | |||||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||||
| Inputs: | |||||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||||
| Outputs: | |||||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||||
| Examples: | |||||
| >>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU') | |||||
| >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | |||||
| >>> net(input).shape | |||||
| (1, 240, 1024, 640) | |||||
| """ | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride=1, | |||||
| pad_mode='same', | |||||
| padding=0, | |||||
| dilation=1, | |||||
| group=1, | |||||
| has_bias=False, | |||||
| weight_init='normal', | |||||
| bias_init='zeros', | |||||
| batchnorm=None, | |||||
| activation=None): | |||||
| super(Conv2d, self).__init__() | |||||
| self.conv = conv.Conv2d( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride, | |||||
| pad_mode, | |||||
| padding, | |||||
| dilation, | |||||
| group, | |||||
| has_bias, | |||||
| weight_init, | |||||
| bias_init) | |||||
| self.has_bn = batchnorm is not None | |||||
| self.has_act = activation is not None | |||||
| self.batchnorm = batchnorm | |||||
| if batchnorm is True: | |||||
| self.batchnorm = BatchNorm2d(out_channels) | |||||
| elif batchnorm is not None: | |||||
| validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) | |||||
| self.activation = get_activation(activation) | |||||
| def construct(self, x): | |||||
| x = self.conv(x) | |||||
| if self.has_bn: | |||||
| x = self.batchnorm(x) | |||||
| if self.has_act: | |||||
| x = self.activation(x) | |||||
| return x | |||||
| class Dense(Cell): | |||||
| r""" | |||||
| A combination of Dense, Batchnorm, activation layer. | |||||
| For a more Detailed overview of Dense op. | |||||
| Args: | |||||
| in_channels (int): The number of channels in the input space. | |||||
| out_channels (int): The number of channels in the output space. | |||||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype | |||||
| is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. | |||||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | |||||
| same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. | |||||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||||
| activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. | |||||
| batchnorm (bool): Specifies to used batchnorm or not. Default: None. | |||||
| activation (string): Specifies activation type. The optional values are as following: | |||||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||||
| Inputs: | |||||
| - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. | |||||
| Outputs: | |||||
| Tensor of shape :math:`(N, out\_channels)`. | |||||
| Examples: | |||||
| >>> net = nn.Dense(3, 4) | |||||
| >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) | |||||
| >>> net(input) | |||||
| """ | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| weight_init='normal', | |||||
| bias_init='zeros', | |||||
| has_bias=True, | |||||
| batchnorm=None, | |||||
| activation=None): | |||||
| super(Dense, self).__init__() | |||||
| self.dense = basic.Dense( | |||||
| in_channels, | |||||
| out_channels, | |||||
| weight_init, | |||||
| bias_init, | |||||
| has_bias) | |||||
| self.has_bn = batchnorm is not None | |||||
| self.has_act = activation is not None | |||||
| if batchnorm is True: | |||||
| self.batchnorm = BatchNorm2d(out_channels) | |||||
| elif batchnorm is not None: | |||||
| validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) | |||||
| self.activation = get_activation(activation) | |||||
| def construct(self, x): | |||||
| x = self.dense(x) | |||||
| if self.has_bn: | |||||
| x = self.batchnorm(x) | |||||
| if self.has_act: | |||||
| x = self.activation(x) | |||||
| return x | |||||
| @@ -27,8 +27,16 @@ from mindspore._checkparam import Validator as validator, Rel | |||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore.nn.layer.activation import get_activation | from mindspore.nn.layer.activation import get_activation | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from .normalization import BatchNorm2d | |||||
| from .activation import get_activation | |||||
| from ..cell import Cell | |||||
| from . import conv, basic | |||||
| from ..._checkparam import ParamValidator as validator | |||||
| __all__ = [ | __all__ = [ | ||||
| 'Conv2dBnAct', | |||||
| 'DenseBnAct', | |||||
| 'FakeQuantWithMinMax', | 'FakeQuantWithMinMax', | ||||
| 'Conv2dBatchNormQuant', | 'Conv2dBatchNormQuant', | ||||
| 'Conv2dQuant', | 'Conv2dQuant', | ||||
| @@ -42,6 +50,165 @@ __all__ = [ | |||||
| ] | ] | ||||
| class Conv2dBnAct(Cell): | |||||
| r""" | |||||
| A combination of convolution, Batchnorm, activation layer. | |||||
| For a more Detailed overview of Conv2d op. | |||||
| Args: | |||||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||||
| kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height | |||||
| and width of the 2D convolution window. Single int means the value if for both height and width of | |||||
| the kernel. A tuple of 2 ints means the first value is for the height and the other is for the | |||||
| width of the kernel. | |||||
| stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be | |||||
| greater or equal to 1 but bounded by the height and width of the input. Default: 1. | |||||
| pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". | |||||
| padding (int): Implicit paddings on both sides of the input. Default: 0. | |||||
| dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`, | |||||
| there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater | |||||
| or equal to 1 and bounded by the height and width of the input. Default: 1. | |||||
| group (int): Split filter into groups, `in_ channels` and `out_channels` should be | |||||
| divisible by the number of groups. Default: 1. | |||||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. | |||||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. | |||||
| It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, | |||||
| values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well | |||||
| as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' | |||||
| and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of | |||||
| Initializer for more details. Default: 'normal'. | |||||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible | |||||
| Initializer and string are the same as 'weight_init'. Refer to the values of | |||||
| Initializer for more details. Default: 'zeros'. | |||||
| batchnorm (bool): Specifies to used batchnorm or not. Default: None. | |||||
| activation (string): Specifies activation type. The optional values are as following: | |||||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||||
| Inputs: | |||||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||||
| Outputs: | |||||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||||
| Examples: | |||||
| >>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU') | |||||
| >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | |||||
| >>> net(input).shape | |||||
| (1, 240, 1024, 640) | |||||
| """ | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride=1, | |||||
| pad_mode='same', | |||||
| padding=0, | |||||
| dilation=1, | |||||
| group=1, | |||||
| has_bias=False, | |||||
| weight_init='normal', | |||||
| bias_init='zeros', | |||||
| batchnorm=None, | |||||
| activation=None): | |||||
| super(Conv2dBnAct, self).__init__() | |||||
| self.conv = conv.Conv2d( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride, | |||||
| pad_mode, | |||||
| padding, | |||||
| dilation, | |||||
| group, | |||||
| has_bias, | |||||
| weight_init, | |||||
| bias_init) | |||||
| self.has_bn = batchnorm is not None | |||||
| self.has_act = activation is not None | |||||
| self.batchnorm = batchnorm | |||||
| if batchnorm is True: | |||||
| self.batchnorm = BatchNorm2d(out_channels) | |||||
| elif batchnorm is not None: | |||||
| validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) | |||||
| self.activation = get_activation(activation) | |||||
| def construct(self, x): | |||||
| x = self.conv(x) | |||||
| if self.has_bn: | |||||
| x = self.batchnorm(x) | |||||
| if self.has_act: | |||||
| x = self.activation(x) | |||||
| return x | |||||
| class DenseBnAct(Cell): | |||||
| r""" | |||||
| A combination of Dense, Batchnorm, activation layer. | |||||
| For a more Detailed overview of Dense op. | |||||
| Args: | |||||
| in_channels (int): The number of channels in the input space. | |||||
| out_channels (int): The number of channels in the output space. | |||||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype | |||||
| is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. | |||||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | |||||
| same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. | |||||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||||
| activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. | |||||
| batchnorm (bool): Specifies to used batchnorm or not. Default: None. | |||||
| activation (string): Specifies activation type. The optional values are as following: | |||||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||||
| Inputs: | |||||
| - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. | |||||
| Outputs: | |||||
| Tensor of shape :math:`(N, out\_channels)`. | |||||
| Examples: | |||||
| >>> net = nn.Dense(3, 4) | |||||
| >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) | |||||
| >>> net(input) | |||||
| """ | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| weight_init='normal', | |||||
| bias_init='zeros', | |||||
| has_bias=True, | |||||
| batchnorm=None, | |||||
| activation=None): | |||||
| super(DenseBnAct, self).__init__() | |||||
| self.dense = basic.Dense( | |||||
| in_channels, | |||||
| out_channels, | |||||
| weight_init, | |||||
| bias_init, | |||||
| has_bias) | |||||
| self.has_bn = batchnorm is not None | |||||
| self.has_act = activation is not None | |||||
| if batchnorm is True: | |||||
| self.batchnorm = BatchNorm2d(out_channels) | |||||
| elif batchnorm is not None: | |||||
| validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) | |||||
| self.activation = get_activation(activation) | |||||
| def construct(self, x): | |||||
| x = self.dense(x) | |||||
| if self.has_bn: | |||||
| x = self.batchnorm(x) | |||||
| if self.has_act: | |||||
| x = self.activation(x) | |||||
| return x | |||||
| class BatchNormFoldCell(Cell): | class BatchNormFoldCell(Cell): | ||||
| """ | """ | ||||
| Batch normalization folded. | Batch normalization folded. | ||||
| @@ -302,8 +469,8 @@ class Conv2dBatchNormQuant(Cell): | |||||
| # initialize convolution op and Parameter | # initialize convolution op and Parameter | ||||
| if context.get_context('device_target') == "Ascend" and group > 1: | if context.get_context('device_target') == "Ascend" and group > 1: | ||||
| validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant') | |||||
| validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant') | |||||
| validator.check_integer('group', group, in_channels, Rel.EQ) | |||||
| validator.check_integer('group', group, out_channels, Rel.EQ) | |||||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | ||||
| kernel_size=self.kernel_size, | kernel_size=self.kernel_size, | ||||
| pad_mode=pad_mode, | pad_mode=pad_mode, | ||||
| @@ -19,7 +19,6 @@ from ... import nn | |||||
| from ... import ops | from ... import ops | ||||
| from ..._checkparam import ParamValidator as validator | from ..._checkparam import ParamValidator as validator | ||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...nn.layer import combined | |||||
| from ...nn.layer import quant | from ...nn.layer import quant | ||||
| _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, | _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, | ||||
| @@ -123,13 +122,13 @@ class ConvertToQuantNetwork: | |||||
| subcell = cells[name] | subcell = cells[name] | ||||
| if subcell == network: | if subcell == network: | ||||
| continue | continue | ||||
| elif isinstance(subcell, combined.Conv2d): | |||||
| elif isinstance(subcell, quant.Conv2dBnAct): | |||||
| prefix = subcell.param_prefix | prefix = subcell.param_prefix | ||||
| new_subcell = self._convert_conv(subcell) | new_subcell = self._convert_conv(subcell) | ||||
| new_subcell.update_parameters_name(prefix + '.') | new_subcell.update_parameters_name(prefix + '.') | ||||
| network.insert_child_to_cell(name, new_subcell) | network.insert_child_to_cell(name, new_subcell) | ||||
| change = True | change = True | ||||
| elif isinstance(subcell, combined.Dense): | |||||
| elif isinstance(subcell, quant.DenseBnAct): | |||||
| prefix = subcell.param_prefix | prefix = subcell.param_prefix | ||||
| new_subcell = self._convert_dense(subcell) | new_subcell = self._convert_dense(subcell) | ||||
| new_subcell.update_parameters_name(prefix + '.') | new_subcell.update_parameters_name(prefix + '.') | ||||
| @@ -159,7 +158,7 @@ class ConvertToQuantNetwork: | |||||
| def _convert_conv(self, subcell): | def _convert_conv(self, subcell): | ||||
| """ | """ | ||||
| convet conv cell to combine cell | |||||
| convet conv cell to quant cell | |||||
| """ | """ | ||||
| conv_inner = subcell.conv | conv_inner = subcell.conv | ||||
| bn_inner = subcell.batchnorm | bn_inner = subcell.batchnorm | ||||
| @@ -1,6 +1,5 @@ | |||||
| """mobile net v2""" | """mobile net v2""" | ||||
| from mindspore import nn | from mindspore import nn | ||||
| from mindspore.nn.layer import combined | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| @@ -14,11 +13,11 @@ def _conv_bn(in_channel, | |||||
| stride=1): | stride=1): | ||||
| """Get a conv2d batchnorm and relu layer.""" | """Get a conv2d batchnorm and relu layer.""" | ||||
| return nn.SequentialCell( | return nn.SequentialCell( | ||||
| [combined.Conv2d(in_channel, | |||||
| out_channel, | |||||
| kernel_size=ksize, | |||||
| stride=stride, | |||||
| batchnorm=True)]) | |||||
| [nn.Conv2dBnAct(in_channel, | |||||
| out_channel, | |||||
| kernel_size=ksize, | |||||
| stride=stride, | |||||
| batchnorm=True)]) | |||||
| class InvertedResidual(nn.Cell): | class InvertedResidual(nn.Cell): | ||||
| @@ -31,30 +30,30 @@ class InvertedResidual(nn.Cell): | |||||
| self.use_res_connect = self.stride == 1 and inp == oup | self.use_res_connect = self.stride == 1 and inp == oup | ||||
| if expend_ratio == 1: | if expend_ratio == 1: | ||||
| self.conv = nn.SequentialCell([ | self.conv = nn.SequentialCell([ | ||||
| combined.Conv2d(hidden_dim, | |||||
| hidden_dim, | |||||
| 3, | |||||
| stride, | |||||
| group=hidden_dim, | |||||
| batchnorm=True, | |||||
| activation='relu6'), | |||||
| combined.Conv2d(hidden_dim, oup, 1, 1, | |||||
| batchnorm=True) | |||||
| nn.Conv2dBnAct(hidden_dim, | |||||
| hidden_dim, | |||||
| 3, | |||||
| stride, | |||||
| group=hidden_dim, | |||||
| batchnorm=True, | |||||
| activation='relu6'), | |||||
| nn.Conv2dBnAct(hidden_dim, oup, 1, 1, | |||||
| batchnorm=True) | |||||
| ]) | ]) | ||||
| else: | else: | ||||
| self.conv = nn.SequentialCell([ | self.conv = nn.SequentialCell([ | ||||
| combined.Conv2d(inp, hidden_dim, 1, 1, | |||||
| batchnorm=True, | |||||
| activation='relu6'), | |||||
| combined.Conv2d(hidden_dim, | |||||
| hidden_dim, | |||||
| 3, | |||||
| stride, | |||||
| group=hidden_dim, | |||||
| batchnorm=True, | |||||
| activation='relu6'), | |||||
| combined.Conv2d(hidden_dim, oup, 1, 1, | |||||
| batchnorm=True) | |||||
| nn.Conv2dBnAct(inp, hidden_dim, 1, 1, | |||||
| batchnorm=True, | |||||
| activation='relu6'), | |||||
| nn.Conv2dBnAct(hidden_dim, | |||||
| hidden_dim, | |||||
| 3, | |||||
| stride, | |||||
| group=hidden_dim, | |||||
| batchnorm=True, | |||||
| activation='relu6'), | |||||
| nn.Conv2dBnAct(hidden_dim, oup, 1, 1, | |||||
| batchnorm=True) | |||||
| ]) | ]) | ||||
| self.add = P.TensorAdd() | self.add = P.TensorAdd() | ||||
| @@ -99,7 +98,7 @@ class MobileNetV2(nn.Cell): | |||||
| self.features = nn.SequentialCell(features) | self.features = nn.SequentialCell(features) | ||||
| self.mean = P.ReduceMean(keep_dims=False) | self.mean = P.ReduceMean(keep_dims=False) | ||||
| self.classifier = combined.Dense(self.last_channel, num_class) | |||||
| self.classifier = nn.DenseBnAct(self.last_channel, num_class) | |||||
| def construct(self, input_x): | def construct(self, input_x): | ||||
| out = input_x | out = input_x | ||||
| @@ -15,7 +15,7 @@ | |||||
| """ tests for quant """ | """ tests for quant """ | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore import nn | from mindspore import nn | ||||
| from mindspore.nn.layer import combined | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| @@ -37,12 +37,11 @@ class LeNet5(nn.Cell): | |||||
| def __init__(self, num_class=10): | def __init__(self, num_class=10): | ||||
| super(LeNet5, self).__init__() | super(LeNet5, self).__init__() | ||||
| self.num_class = num_class | self.num_class = num_class | ||||
| self.conv1 = combined.Conv2d( | |||||
| 1, 6, kernel_size=5, batchnorm=True, activation='relu6') | |||||
| self.conv2 = combined.Conv2d(6, 16, kernel_size=5, activation='relu') | |||||
| self.fc1 = combined.Dense(16 * 5 * 5, 120, activation='relu') | |||||
| self.fc2 = combined.Dense(120, 84, activation='relu') | |||||
| self.fc3 = combined.Dense(84, self.num_class) | |||||
| self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6') | |||||
| self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu') | |||||
| self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') | |||||
| self.fc2 = nn.DenseBnAct(120, 84, activation='relu') | |||||
| self.fc3 = nn.DenseBnAct(84, self.num_class) | |||||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | ||||
| self.flattern = nn.Flatten() | self.flattern = nn.Flatten() | ||||