Merge pull request !1084 from SanjayChan/04quanttags/v0.3.0-alpha
| @@ -97,7 +97,7 @@ class Cell: | |||
| After invoked, can get all the cell's children's name prefix by '_param_prefix'. | |||
| """ | |||
| cells = self.cells_and_names | |||
| cells = self.cells_and_names() | |||
| for cell_name, cell in cells: | |||
| cell._param_prefix = cell_name | |||
| @@ -0,0 +1,182 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Use combination of Conv, Dense, Relu, Batchnorm.""" | |||
| from .normalization import BatchNorm2d | |||
| from .activation import get_activation | |||
| from ..cell import Cell | |||
| from . import conv, basic | |||
| from ..._checkparam import ParamValidator as validator | |||
| __all__ = ['Conv2d', 'Dense'] | |||
| class Conv2d(Cell): | |||
| r""" | |||
| A combination of convolution, Batchnorm, activation layer. | |||
| For a more Detailed overview of Conv2d op. | |||
| Args: | |||
| in_channels (int): The number of input channel :math:`C_{in}`. | |||
| out_channels (int): The number of output channel :math:`C_{out}`. | |||
| kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height | |||
| and width of the 2D convolution window. Single int means the value if for both height and width of | |||
| the kernel. A tuple of 2 ints means the first value is for the height and the other is for the | |||
| width of the kernel. | |||
| stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be | |||
| greater or equal to 1 but bounded by the height and width of the input. Default: 1. | |||
| pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". | |||
| padding (int): Implicit paddings on both sides of the input. Default: 0. | |||
| dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`, | |||
| there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater | |||
| or equal to 1 and bounded by the height and width of the input. Default: 1. | |||
| group (int): Split filter into groups, `in_ channels` and `out_channels` should be | |||
| divisible by the number of groups. Default: 1. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. | |||
| It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, | |||
| values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well | |||
| as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' | |||
| and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of | |||
| Initializer for more details. Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible | |||
| Initializer and string are the same as 'weight_init'. Refer to the values of | |||
| Initializer for more details. Default: 'zeros'. | |||
| batchnorm (bool): Specifies to used batchnorm or not. Default: None. | |||
| activation (string): Specifies activation type. The optional values are as following: | |||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
| Examples: | |||
| >>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU') | |||
| >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) | |||
| >>> net(input).shape() | |||
| (1, 240, 1024, 640) | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| pad_mode='same', | |||
| padding=0, | |||
| dilation=1, | |||
| group=1, | |||
| has_bias=False, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| batchnorm=None, | |||
| activation=None): | |||
| super(Conv2d, self).__init__() | |||
| self.conv = conv.Conv2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| pad_mode, | |||
| padding, | |||
| dilation, | |||
| group, | |||
| has_bias, | |||
| weight_init, | |||
| bias_init) | |||
| self.has_bn = batchnorm is not None | |||
| self.has_act = activation is not None | |||
| self.batchnorm = batchnorm | |||
| if batchnorm is True: | |||
| self.batchnorm = BatchNorm2d(out_channels) | |||
| elif batchnorm is not None: | |||
| validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) | |||
| self.activation = get_activation(activation) | |||
| def construct(self, x): | |||
| x = self.conv(x) | |||
| if self.has_bn: | |||
| x = self.batchnorm(x) | |||
| if self.has_act: | |||
| x = self.activation(x) | |||
| return x | |||
| class Dense(Cell): | |||
| r""" | |||
| A combination of Dense, Batchnorm, activation layer. | |||
| For a more Detailed overview of Dense op. | |||
| Args: | |||
| in_channels (int): The number of channels in the input space. | |||
| out_channels (int): The number of channels in the output space. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype | |||
| is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | |||
| same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||
| activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. | |||
| batchnorm (bool): Specifies to used batchnorm or not. Default: None. | |||
| activation (string): Specifies activation type. The optional values are as following: | |||
| 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', | |||
| 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. | |||
| Outputs: | |||
| Tensor of shape :math:`(N, out\_channels)`. | |||
| Examples: | |||
| >>> net = nn.Dense(3, 4) | |||
| >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) | |||
| >>> net(input) | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| has_bias=True, | |||
| batchnorm=None, | |||
| activation=None): | |||
| super(Dense, self).__init__() | |||
| self.dense = basic.Dense( | |||
| in_channels, | |||
| out_channels, | |||
| weight_init, | |||
| bias_init, | |||
| has_bias) | |||
| self.has_bn = batchnorm is not None | |||
| self.has_act = activation is not None | |||
| if batchnorm is True: | |||
| self.batchnorm = BatchNorm2d(out_channels) | |||
| elif batchnorm is not None: | |||
| validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) | |||
| self.activation = get_activation(activation) | |||
| def construct(self, x): | |||
| x = self.dense(x) | |||
| if self.has_bn: | |||
| x = self.batchnorm(x) | |||
| if self.has_act: | |||
| x = self.activation(x) | |||
| return x | |||
| @@ -191,6 +191,8 @@ class Conv2dBatchNormQuant(Cell): | |||
| stride, | |||
| pad_mode, | |||
| padding=0, | |||
| dilation=1, | |||
| group=1, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| weight_init=None, | |||
| @@ -198,7 +200,6 @@ class Conv2dBatchNormQuant(Cell): | |||
| gamma_init=None, | |||
| mean_init=None, | |||
| var_init=None, | |||
| group=1, | |||
| quant_delay=0, | |||
| freeze_bn=100000, | |||
| fake=True, | |||
| @@ -0,0 +1,26 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| quantization. | |||
| User can use aware quantization to train a model. Mindspore supports quantization aware training, | |||
| which models quantization errors in both the forward and backward passes using fake-quantization | |||
| ops. Note that the entire computation is carried out in floating point. At the end of quantization | |||
| aware training, Mindspore provides conversion functions to convert the trained model into lower precision. | |||
| """ | |||
| from .quant import convert_quant_network | |||
| __all__ = ["convert_quant_network"] | |||
| @@ -0,0 +1,262 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """aware quantization.""" | |||
| import re | |||
| from ... import nn | |||
| from ... import ops | |||
| from ..._checkparam import ParamValidator as validator | |||
| from ..._checkparam import Rel | |||
| from ...nn.layer import combined | |||
| from ...nn.layer import quant | |||
| _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, | |||
| nn.ReLU6: quant.ReLU6Quant, | |||
| nn.HSigmoid: quant.HSigmoidQuant, | |||
| nn.HSwish: quant.HSwishQuant} | |||
| class _AddFakeQuantInputOutput(nn.Cell): | |||
| """ | |||
| Add FakeQuant at input and output of the Network. Only support one input and one output case. | |||
| """ | |||
| def __init__(self, network, quant_delay=0): | |||
| super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.fake_quant_input = quant.FakeQuantWithMinMax( | |||
| min_init=-6, max_init=6, quant_delay=quant_delay, ema=True) | |||
| self.fake_quant_input.update_parameters_name('fake_quant_input') | |||
| self.fake_quant_output = quant.FakeQuantWithMinMax( | |||
| min_init=-6, max_init=6, quant_delay=quant_delay, ema=True) | |||
| self.fake_quant_output.update_parameters_name('fake_quant_output') | |||
| def construct(self, data): | |||
| data = self.fake_quant_input(data) | |||
| output = self.network(data) | |||
| output = self.fake_quant_output(output) | |||
| return output | |||
| class _AddFakeQuantAfterSubCell(nn.Cell): | |||
| """ | |||
| Add FakeQuant after of the sub Cell. | |||
| """ | |||
| def __init__(self, subcell, quant_delay=0, num_bits=8): | |||
| super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) | |||
| self.subcell = subcell | |||
| self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True) | |||
| def construct(self, *data): | |||
| output = self.subcell(*data) | |||
| output = self.fake_quant_act(output) | |||
| return output | |||
| class ConvertToQuantNetwork: | |||
| """ | |||
| Convert network to quantization aware network | |||
| """ | |||
| __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] | |||
| def __init__(self, | |||
| network, | |||
| quant_delay=0, | |||
| bn_fold=False, | |||
| freeze_bn=0, | |||
| weight_bits=8, | |||
| act_bits=8, | |||
| per_channel=False, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| self.network = validator.check_isinstance( | |||
| 'network', network, (nn.Cell,)) | |||
| self.quant_delay = validator.check_integer( | |||
| "quant delay", quant_delay, 0, Rel.GE) | |||
| self.freeze_bn = validator.check_integer( | |||
| "freeze bn", freeze_bn, 0, Rel.GE) | |||
| self.weight_bits = validator.check_integer( | |||
| "weights bit", weight_bits, 0, Rel.GE) | |||
| self.act_bits = validator.check_integer( | |||
| "activations bit", act_bits, 0, Rel.GE) | |||
| self.bn_fold = validator.check_bool("bn fold", bn_fold) | |||
| self.per_channel = validator.check_bool("per channel", per_channel) | |||
| self.symmetric = validator.check_bool("symmetric", symmetric) | |||
| self.narrow_range = validator.check_bool("narrow range", narrow_range) | |||
| def _convert_op_name(self, name): | |||
| pattern = re.compile(r'([A-Z]{1})') | |||
| name_new = re.sub(pattern, r'_\1', name).lower() | |||
| if name_new[0] == '_': | |||
| name_new = name_new[1:] | |||
| return name_new | |||
| def run(self): | |||
| self.network.update_cell_prefix() | |||
| network = self._convert_subcells2quant(self.network) | |||
| return network | |||
| def _convert_subcells2quant(self, network): | |||
| """ | |||
| convet sub cell to quant cell | |||
| """ | |||
| cells = network.name_cells() | |||
| change = False | |||
| for name in cells: | |||
| subcell = cells[name] | |||
| if subcell == network: | |||
| continue | |||
| elif isinstance(subcell, combined.Conv2d): | |||
| prefix = subcell.param_prefix | |||
| new_subcell = self._convert_conv(subcell) | |||
| new_subcell.update_parameters_name(prefix + '.') | |||
| network.insert_child_to_cell(name, new_subcell) | |||
| change = True | |||
| elif isinstance(subcell, combined.Dense): | |||
| prefix = subcell.param_prefix | |||
| new_subcell = self._convert_dense(subcell) | |||
| new_subcell.update_parameters_name(prefix + '.') | |||
| network.insert_child_to_cell(name, new_subcell) | |||
| change = True | |||
| else: | |||
| self._convert_subcells2quant(subcell) | |||
| if isinstance(network, nn.SequentialCell) and change: | |||
| network.cell_list = list(network.cells()) | |||
| # tensoradd to tensoradd quant | |||
| add_list = [] | |||
| for name in network.__dict__: | |||
| if name[0] == '_': | |||
| continue | |||
| attr = network.__dict__[name] | |||
| if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__: | |||
| add_list.append((name, attr)) | |||
| for name, prim_op in add_list: | |||
| prefix = name | |||
| add_quant = _AddFakeQuantAfterSubCell(prim_op) # quant.TensorAddQuant() | |||
| prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) | |||
| add_quant.update_parameters_name(prefix + '.') | |||
| del network.__dict__[name] | |||
| network.insert_child_to_cell(name, add_quant) | |||
| return network | |||
| def _convert_conv(self, subcell): | |||
| """ | |||
| convet conv cell to combine cell | |||
| """ | |||
| conv_inner = subcell.conv | |||
| bn_inner = subcell.batchnorm | |||
| if subcell.batchnorm is not None and self.bn_fold: | |||
| conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels, | |||
| conv_inner.out_channels, | |||
| kernel_size=conv_inner.kernel_size, | |||
| stride=conv_inner.stride, | |||
| pad_mode=conv_inner.pad_mode, | |||
| padding=conv_inner.padding, | |||
| dilation=conv_inner.dilation, | |||
| group=conv_inner.group, | |||
| eps=bn_inner.eps, | |||
| momentum=bn_inner.momentum, | |||
| quant_delay=self.quant_delay, | |||
| freeze_bn=self.freeze_bn, | |||
| per_channel=self.per_channel, | |||
| num_bits=self.weight_bits, | |||
| fake=True, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range) | |||
| del subcell.batchnorm | |||
| subcell.batchnorm = None | |||
| subcell.has_bn = False | |||
| else: | |||
| conv_inner = quant.Conv2dQuant(conv_inner.in_channels, | |||
| conv_inner.out_channels, | |||
| kernel_size=conv_inner.kernel_size, | |||
| stride=conv_inner.stride, | |||
| pad_mode=conv_inner.pad_mode, | |||
| padding=conv_inner.padding, | |||
| dilation=conv_inner.dilation, | |||
| group=conv_inner.group, | |||
| has_bias=conv_inner.has_bias, | |||
| quant_delay=self.quant_delay, | |||
| per_channel=self.per_channel, | |||
| num_bits=self.weight_bits, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range) | |||
| subcell.conv = conv_inner | |||
| if subcell.activation is not None: | |||
| subcell.activation = self._convert_activation(subcell.activation) | |||
| else: | |||
| subcell = _AddFakeQuantAfterSubCell(subcell) | |||
| return subcell | |||
| def _convert_dense(self, subcell): | |||
| """ | |||
| convert dense cell to combine dense cell | |||
| """ | |||
| dense_inner = subcell.dense | |||
| dense_inner = quant.DenseQuant(dense_inner.in_channels, | |||
| dense_inner.out_channels, | |||
| has_bias=dense_inner.has_bias, | |||
| quant_delay=self.quant_delay, | |||
| per_channel=self.per_channel, | |||
| num_bits=self.weight_bits) | |||
| subcell.dense = dense_inner | |||
| if subcell.activation is not None: | |||
| subcell.activation = self._convert_activation(subcell.activation) | |||
| return subcell | |||
| def _convert_activation(self, activation): | |||
| act_class = activation.__class__ | |||
| if act_class not in _ACTIVATION_MAP: | |||
| raise ValueError( | |||
| "Unsupported activation in auto Quant: ", act_class) | |||
| return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay) | |||
| def convert_quant_network(network, | |||
| quant_delay=0, | |||
| bn_fold=False, | |||
| freeze_bn=0, | |||
| weight_bits=8, | |||
| act_bits=8, | |||
| per_channel=False, | |||
| symmetric=False, | |||
| narrow_range=False | |||
| ): | |||
| r""" | |||
| Create aware quantizaiton training network. | |||
| Args: | |||
| network (Cell): Obtain a pipeline through network for saving graph summary. | |||
| quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0. | |||
| bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. | |||
| freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0. | |||
| weight_bits (int): Number of bits to use for quantizing weights. Default: 8. | |||
| act_bits (int): Number of bits to use for quantizing activations. Default: 8. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| returns: | |||
| Cell, Network which has change to aware quantization training network. | |||
| """ | |||
| net = ConvertToQuantNetwork( | |||
| network, quant_delay, bn_fold, freeze_bn, weight_bits, act_bits, per_channel, symmetric, narrow_range) | |||
| return net.run() | |||
| @@ -0,0 +1,100 @@ | |||
| """MobileNetV2""" | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as P | |||
| def make_divisible(input_x, div_by=8): | |||
| return int((input_x + div_by) // div_by) | |||
| def _conv_bn(in_channel, | |||
| out_channel, | |||
| ksize, | |||
| stride=1): | |||
| """Get a conv2d batchnorm and relu layer.""" | |||
| return nn.SequentialCell( | |||
| [nn.Conv2d(in_channel, | |||
| out_channel, | |||
| kernel_size=ksize, | |||
| stride=stride), | |||
| nn.BatchNorm2d(out_channel)]) | |||
| class InvertedResidual(nn.Cell): | |||
| def __init__(self, inp, oup, stride, expend_ratio): | |||
| super(InvertedResidual, self).__init__() | |||
| self.stride = stride | |||
| assert stride in [1, 2] | |||
| hidden_dim = int(inp * expend_ratio) | |||
| self.use_res_connect = self.stride == 1 and inp == oup | |||
| if expend_ratio == 1: | |||
| self.conv = nn.SequentialCell([ | |||
| nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim), | |||
| nn.BatchNorm2d(hidden_dim), | |||
| nn.ReLU6(), | |||
| nn.Conv2d(hidden_dim, oup, 1, 1), | |||
| nn.BatchNorm2d(oup) | |||
| ]) | |||
| else: | |||
| self.conv = nn.SequentialCell([ | |||
| nn.Conv2d(inp, hidden_dim, 1, 1), | |||
| nn.BatchNorm2d(hidden_dim), | |||
| nn.ReLU6(), | |||
| nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim), | |||
| nn.BatchNorm2d(hidden_dim), | |||
| nn.ReLU6(), | |||
| nn.Conv2d(hidden_dim, oup, 1, 1), | |||
| nn.BatchNorm2d(oup) | |||
| ]) | |||
| def construct(self, input_x): | |||
| out = self.conv(input_x) | |||
| if self.use_res_connect: | |||
| out = input_x + out | |||
| return out | |||
| class MobileNetV2(nn.Cell): | |||
| def __init__(self, num_class=1000, input_size=224, width_mul=1.): | |||
| super(MobileNetV2, self).__init__() | |||
| block = InvertedResidual | |||
| input_channel = 32 | |||
| last_channel = 1280 | |||
| inverted_residual_setting = [ | |||
| [1, 16, 1, 1], | |||
| [6, 24, 2, 2], | |||
| [6, 32, 3, 2], | |||
| [6, 64, 4, 2], | |||
| [6, 96, 3, 1], | |||
| [6, 160, 3, 2], | |||
| [6, 230, 1, 1], | |||
| ] | |||
| if width_mul > 1.0: | |||
| last_channel = make_divisible(last_channel * width_mul) | |||
| self.last_channel = last_channel | |||
| features = [_conv_bn(3, input_channel, 3, 2)] | |||
| for t, c, n, s in inverted_residual_setting: | |||
| out_channel = make_divisible(c * width_mul) if t > 1 else c | |||
| for i in range(n): | |||
| if i == 0: | |||
| features.append(block(input_channel, out_channel, s, t)) | |||
| else: | |||
| features.append(block(input_channel, out_channel, 1, t)) | |||
| input_channel = out_channel | |||
| features.append(_conv_bn(input_channel, self.last_channel, 1)) | |||
| self.features = nn.SequentialCell(features) | |||
| self.mean = P.ReduceMean(keep_dims=False) | |||
| self.classifier = nn.Dense(self.last_channel, num_class) | |||
| def construct(self, input_x): | |||
| out = input_x | |||
| out = self.features(out) | |||
| out = self.mean(out, (2, 3)) | |||
| out = self.classifier(out) | |||
| return out | |||
| @@ -0,0 +1,108 @@ | |||
| """mobile net v2""" | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.layer import combined | |||
| def make_divisible(input_x, div_by=8): | |||
| return int((input_x + div_by) // div_by) | |||
| def _conv_bn(in_channel, | |||
| out_channel, | |||
| ksize, | |||
| stride=1): | |||
| """Get a conv2d batchnorm and relu layer.""" | |||
| return nn.SequentialCell( | |||
| [combined.Conv2d(in_channel, | |||
| out_channel, | |||
| kernel_size=ksize, | |||
| stride=stride, | |||
| batchnorm=True)]) | |||
| class InvertedResidual(nn.Cell): | |||
| def __init__(self, inp, oup, stride, expend_ratio): | |||
| super(InvertedResidual, self).__init__() | |||
| self.stride = stride | |||
| assert stride in [1, 2] | |||
| hidden_dim = int(inp * expend_ratio) | |||
| self.use_res_connect = self.stride == 1 and inp == oup | |||
| if expend_ratio == 1: | |||
| self.conv = nn.SequentialCell([ | |||
| combined.Conv2d(hidden_dim, | |||
| hidden_dim, | |||
| 3, | |||
| stride, | |||
| group=hidden_dim, | |||
| batchnorm=True, | |||
| activation='relu6'), | |||
| combined.Conv2d(hidden_dim, oup, 1, 1, | |||
| batchnorm=True) | |||
| ]) | |||
| else: | |||
| self.conv = nn.SequentialCell([ | |||
| combined.Conv2d(inp, hidden_dim, 1, 1, | |||
| batchnorm=True, | |||
| activation='relu6'), | |||
| combined.Conv2d(hidden_dim, | |||
| hidden_dim, | |||
| 3, | |||
| stride, | |||
| group=hidden_dim, | |||
| batchnorm=True, | |||
| activation='relu6'), | |||
| combined.Conv2d(hidden_dim, oup, 1, 1, | |||
| batchnorm=True) | |||
| ]) | |||
| self.add = P.TensorAdd() | |||
| def construct(self, input_x): | |||
| out = self.conv(input_x) | |||
| if self.use_res_connect: | |||
| out = self.add(input_x, out) | |||
| return out | |||
| class MobileNetV2(nn.Cell): | |||
| def __init__(self, num_class=1000, input_size=224, width_mul=1.): | |||
| super(MobileNetV2, self).__init__() | |||
| block = InvertedResidual | |||
| input_channel = 32 | |||
| last_channel = 1280 | |||
| inverted_residual_setting = [ | |||
| [1, 16, 1, 1], | |||
| [6, 24, 2, 2], | |||
| [6, 32, 3, 2], | |||
| [6, 64, 4, 2], | |||
| [6, 96, 3, 1], | |||
| [6, 160, 3, 2], | |||
| [6, 230, 1, 1], | |||
| ] | |||
| if width_mul > 1.0: | |||
| last_channel = make_divisible(last_channel * width_mul) | |||
| self.last_channel = last_channel | |||
| features = [_conv_bn(3, input_channel, 3, 2)] | |||
| for t, c, n, s in inverted_residual_setting: | |||
| out_channel = make_divisible(c * width_mul) if t > 1 else c | |||
| for i in range(n): | |||
| if i == 0: | |||
| features.append(block(input_channel, out_channel, s, t)) | |||
| else: | |||
| features.append(block(input_channel, out_channel, 1, t)) | |||
| input_channel = out_channel | |||
| features.append(_conv_bn(input_channel, self.last_channel, 1)) | |||
| self.features = nn.SequentialCell(features) | |||
| self.mean = P.ReduceMean(keep_dims=False) | |||
| self.classifier = combined.Dense(self.last_channel, num_class) | |||
| def construct(self, input_x): | |||
| out = input_x | |||
| out = self.features(out) | |||
| out = self.mean(out, (2, 3)) | |||
| out = self.classifier(out) | |||
| return out | |||
| @@ -0,0 +1,94 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ tests for quant """ | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore.train.quant import quant as qat | |||
| from mindspore import nn | |||
| import mindspore.ops.operations as P | |||
| from mindspore.nn.layer import combined | |||
| import mindspore.context as context | |||
| from mobilenetv2_combined import MobileNetV2 | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class LeNet5(nn.Cell): | |||
| """ | |||
| Lenet network | |||
| Args: | |||
| num_class (int): Num classes. Default: 10. | |||
| Returns: | |||
| Tensor, output tensor | |||
| Examples: | |||
| >>> LeNet(num_class=10) | |||
| """ | |||
| def __init__(self, num_class=10): | |||
| super(LeNet5, self).__init__() | |||
| self.num_class = num_class | |||
| self.conv1 = combined.Conv2d( | |||
| 1, 6, kernel_size=5, batchnorm=True, activation='relu6') | |||
| self.conv2 = combined.Conv2d(6, 16, kernel_size=5, activation='relu') | |||
| self.fc1 = combined.Dense(16 * 5 * 5, 120, activation='relu') | |||
| self.fc2 = combined.Dense(120, 84, activation='relu') | |||
| self.fc3 = combined.Dense(84, self.num_class) | |||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.flattern = nn.Flatten() | |||
| def construct(self, x): | |||
| x = self.conv1(x) | |||
| x = self.bn(x) | |||
| x = self.relu(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.conv2(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.flattern(x) | |||
| x = self.fc1(x) | |||
| x = self.fc2(x) | |||
| x = self.fc3(x) | |||
| return x | |||
| def test_qat_lenet(): | |||
| net = LeNet5() | |||
| net = qat.convert_quant_network( | |||
| net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8) | |||
| def test_qat_mobile(): | |||
| net = MobileNetV2() | |||
| img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) | |||
| net = qat.convert_quant_network( | |||
| net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8) | |||
| net(img) | |||
| def test_qat_mobile_train(): | |||
| net = MobileNetV2(num_class=10) | |||
| img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) | |||
| label = Tensor(np.ones((1, 10)).astype(np.float32)) | |||
| net = qat.convert_quant_network( | |||
| net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8) | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') | |||
| optimizer = nn.Momentum(net.trainable_params(), | |||
| learning_rate=0.1, momentum=0.9) | |||
| net = nn.WithLossCell(net, loss) | |||
| net = nn.TrainOneStepCell(net, optimizer) | |||
| net(img, label) | |||