Merge pull request !1745 from SanjayChan/master
(tag: v0.5.0-beta)
@@ -39,10 +39,10 @@ config_gpu = ed({
     "num_classes": 1000,
     "image_height": 224,
     "image_width": 224,
-    "batch_size": 64,
+    "batch_size": 150,
     "epoch_size": 200,
-    "warmup_epochs": 4,
-    "lr": 0.5,
+    "warmup_epochs": 0,
+    "lr": 0.8,
     "momentum": 0.9,
     "weight_decay": 4e-5,
     "label_smooth": 0.1,
@@ -20,20 +20,10 @@ from mindspore.ops.operations import TensorAdd
 from mindspore import Parameter, Tensor
 from mindspore.common.initializer import initializer

-__all__ = ['MobileNetV2', 'mobilenet_v2']
+__all__ = ['mobilenet_v2']


 def _make_divisible(v, divisor, min_value=None):
-    """
-    This function is taken from the original tf repo.
-    It ensures that all layers have a channel number that is divisible by 8
-    It can be seen here:
-    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
-    :param v:
-    :param divisor:
-    :param min_value:
-    :return:
-    """
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
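A quick worked example of the rounding rule kept above (round to the nearest multiple of divisor, floored at min_value):

    _make_divisible(30, 8)   # int(30 + 4) // 8 * 8 = 32
    _make_divisible(12, 8)   # int(12 + 4) // 8 * 8 = 16
    _make_divisible(3, 8)    # rounds down to 0, so max(8, 0) returns 8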
@@ -55,6 +45,7 @@ class GlobalAvgPooling(nn.Cell):
     Examples:
         >>> GlobalAvgPooling()
     """
+
     def __init__(self):
         super(GlobalAvgPooling, self).__init__()
         self.mean = P.ReduceMean(keep_dims=False)
@@ -82,6 +73,7 @@ class DepthwiseConv(nn.Cell):
     Examples:
         >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
     """
+
     def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
         super(DepthwiseConv, self).__init__()
         self.has_bias = has_bias
@@ -126,14 +118,19 @@ class ConvBNReLU(nn.Cell):
     Examples:
         >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
     """
-    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+
+    def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
         super(ConvBNReLU, self).__init__()
         padding = (kernel_size - 1) // 2
         if groups == 1:
-            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
-                             padding=padding)
+            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding)
         else:
-            conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
+            if platform == "Ascend":
+                conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
+            elif platform == "GPU":
+                conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride,
+                                 group=in_planes, pad_mode='pad', padding=padding)

         layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
         self.features = nn.SequentialCell(layers)
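The else-branch now dispatches on platform: Ascend keeps the custom DepthwiseConv cell, while GPU builds nn.Conv2d with group=in_planes, presumably because a grouped conv with one group per channel is exactly a depthwise conv and can run natively there. A usage sketch with illustrative sizes:

    # depthwise 3x3 block: groups == in_planes == out_planes
    dw_gpu    = ConvBNReLU("GPU", 32, 32, kernel_size=3, stride=1, groups=32)
    dw_ascend = ConvBNReLU("Ascend", 32, 32, kernel_size=3, stride=1, groups=32)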
@@ -158,7 +155,8 @@ class InvertedResidual(nn.Cell):
     Examples:
         >>> ResidualBlock(3, 256, 1, 1)
     """
-    def __init__(self, inp, oup, stride, expand_ratio):
+
+    def __init__(self, platform, inp, oup, stride, expand_ratio):
         super(InvertedResidual, self).__init__()
         assert stride in [1, 2]
@@ -167,12 +165,14 @@ class InvertedResidual(nn.Cell):
         layers = []
         if expand_ratio != 1:
-            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+            layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1))
         layers.extend([
             # dw
-            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            ConvBNReLU(platform, hidden_dim, hidden_dim,
+                       stride=stride, groups=hidden_dim),
             # pw-linear
-            nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
+            nn.Conv2d(hidden_dim, oup, kernel_size=1,
+                      stride=1, has_bias=False),
             nn.BatchNorm2d(oup),
         ])
         self.conv = nn.SequentialCell(layers)
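The block keeps the standard MobileNetV2 shape: an optional 1x1 expansion (skipped when expand_ratio == 1), a depthwise 3x3, then a linear 1x1 projection. Only the two ConvBNReLU stages need the platform flag; the projection is a plain pointwise nn.Conv2d. A sketch, assuming the usual hidden_dim = inp * expand_ratio computation above this hunk:

    # expand_ratio=6, inp=24, oup=24, stride=1
    # 1x1: 24 -> 144, dw 3x3: 144 -> 144, pw-linear 1x1: 144 -> 24, plus residual add
    blk = InvertedResidual("GPU", 24, 24, stride=1, expand_ratio=6)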
@@ -203,7 +203,8 @@ class MobileNetV2(nn.Cell):
     Examples:
         >>> MobileNetV2(num_classes=1000)
     """
-    def __init__(self, num_classes=1000, width_mult=1.,
+
+    def __init__(self, platform, num_classes=1000, width_mult=1.,
                  has_dropout=False, inverted_residual_setting=None, round_nearest=8):
         super(MobileNetV2, self).__init__()
         block = InvertedResidual
@@ -226,16 +227,16 @@ class MobileNetV2(nn.Cell):
         # building first layer
         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
         self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        features = [ConvBNReLU(3, input_channel, stride=2)]
+        features = [ConvBNReLU(platform, 3, input_channel, stride=2)]
         # building inverted residual blocks
         for t, c, n, s in self.cfgs:
             output_channel = _make_divisible(c * width_mult, round_nearest)
             for i in range(n):
                 stride = s if i == 0 else 1
-                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t))
                 input_channel = output_channel
         # building last several layers
-        features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
+        features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1))
         # make it nn.CellList
         self.features = nn.SequentialCell(features)
         # mobilenet head
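To make the block-stacking loop concrete: a cfgs row (t, c, n, s) expands into n blocks, and only the first block in the run uses the row's stride. For the reference-table row (6, 32, 3, 2) with width_mult = 1.0:

    # output_channel = _make_divisible(32 * 1.0, 8) = 32
    # i == 0: block(platform, in_ch, 32, stride=2, expand_ratio=6)
    # i == 1: block(platform, 32,    32, stride=1, expand_ratio=6)
    # i == 2: block(platform, 32,    32, stride=1, expand_ratio=6)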
| @@ -268,14 +269,19 @@ class MobileNetV2(nn.Cell): | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | ||||
| m.weight.data.shape()).astype("float32"))) | m.weight.data.shape()).astype("float32"))) | ||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| m.bias.set_parameter_data( | |||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| elif isinstance(m, nn.BatchNorm2d): | elif isinstance(m, nn.BatchNorm2d): | ||||
| m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) | |||||
| m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) | |||||
| m.gamma.set_parameter_data( | |||||
| Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) | |||||
| m.beta.set_parameter_data( | |||||
| Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) | |||||
| elif isinstance(m, nn.Dense): | elif isinstance(m, nn.Dense): | ||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape()).astype("float32"))) | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal( | |||||
| 0, 0.01, m.weight.data.shape()).astype("float32"))) | |||||
| if m.bias is not None: | if m.bias is not None: | ||||
| m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| m.bias.set_parameter_data( | |||||
| Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) | |||||
| def mobilenet_v2(**kwargs): | def mobilenet_v2(**kwargs): | ||||
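With the platform argument threaded through, construction now looks like this (a sketch, assuming the factory forwards its kwargs straight into MobileNetV2.__init__; its body is not shown in this hunk):

    net_gpu    = mobilenet_v2(platform="GPU", num_classes=1000)
    net_ascend = mobilenet_v2(platform="Ascend", num_classes=1000)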
@@ -205,7 +205,7 @@ if __name__ == '__main__':
         config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
                                      keep_checkpoint_max=config_gpu.keep_checkpoint_max)
         ckpt_cb = ModelCheckpoint(
-            prefix="mobilenet", directory=config_gpu.save_checkpoint_path, config=config_ck)
+            prefix="mobilenetV2", directory=config_gpu.save_checkpoint_path, config=config_ck)
         cb += [ckpt_cb]
         # begine train
         model.train(epoch_size, dataset, callbacks=cb)
@@ -265,7 +265,7 @@ if __name__ == '__main__':
         config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size,
                                      keep_checkpoint_max=config_ascend.keep_checkpoint_max)
         ckpt_cb = ModelCheckpoint(
-            prefix="mobilenet", directory=config_ascend.save_checkpoint_path, config=config_ck)
+            prefix="mobilenetV2", directory=config_ascend.save_checkpoint_path, config=config_ck)
         cb += [ckpt_cb]
         model.train(epoch_size, dataset, callbacks=cb)
     else:
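The prefix rename only affects checkpoint file names. MindSpore's ModelCheckpoint writes files named roughly <prefix>-<epoch>_<step>.ckpt (the exact pattern comes from MindSpore, not this diff), so the two networks no longer collide under a shared save_checkpoint_path:

    mobilenetV2-1_625.ckpt    # illustrative epoch/step numbers
    mobilenetV3-1_625.ckpt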
@@ -39,10 +39,10 @@ config_gpu = ed({
     "num_classes": 1000,
     "image_height": 224,
     "image_width": 224,
-    "batch_size": 64,
-    "epoch_size": 300,
+    "batch_size": 150,
+    "epoch_size": 370,
     "warmup_epochs": 4,
-    "lr": 0.5,
+    "lr": 1.54,
     "momentum": 0.9,
     "weight_decay": 4e-5,
     "label_smooth": 0.1,
@@ -0,0 +1,390 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV3 model define"""
from functools import partial

import numpy as np

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor

__all__ = ['mobilenet_v3_large',
           'mobilenet_v3_small']


def _make_divisible(x, divisor=8):
    return int(np.ceil(x * 1. / divisor) * divisor)
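Note that this V3 helper always rounds up (ceil), unlike the V2 version above, which rounds to the nearest multiple. A quick comparison:

    _make_divisible(12)   # ceil(12 / 8) * 8 = 16; the V2 rule also gives 16
    _make_divisible(9)    # ceil(9 / 8) * 8 = 16; the V2 rule gives int(9 + 4) // 8 * 8 = 8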
class Activation(nn.Cell):
    """
    Activation definition.

    Args:
        act_func (string): activation name.

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, act_func):
        super(Activation, self).__init__()
        if act_func == 'relu':
            self.act = nn.ReLU()
        elif act_func == 'relu6':
            self.act = nn.ReLU6()
        elif act_func in ('hsigmoid', 'hard_sigmoid'):
            self.act = nn.HSigmoid()
        elif act_func in ('hswish', 'hard_swish'):
            self.act = nn.HSwish()
        else:
            raise NotImplementedError

    def construct(self, x):
        return self.act(x)


class GlobalAvgPooling(nn.Cell):
    """
    Global avg pooling definition.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> GlobalAvgPooling()
    """

    def __init__(self, keep_dims=False):
        super(GlobalAvgPooling, self).__init__()
        self.mean = P.ReduceMean(keep_dims=keep_dims)

    def construct(self, x):
        x = self.mean(x, (2, 3))
        return x


class SE(nn.Cell):
    """
    SE wrapper definition.

    Args:
        num_out (int): Output channel.
        ratio (int): middle output ratio.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> SE(4)
    """

    def __init__(self, num_out, ratio=4):
        super(SE, self).__init__()
        num_mid = _make_divisible(num_out // ratio)
        self.pool = GlobalAvgPooling(keep_dims=True)
        self.conv1 = nn.Conv2d(in_channels=num_out, out_channels=num_mid,
                               kernel_size=1, has_bias=True, pad_mode='pad')
        self.act1 = Activation('relu')
        self.conv2 = nn.Conv2d(in_channels=num_mid, out_channels=num_out,
                               kernel_size=1, has_bias=True, pad_mode='pad')
        self.act2 = Activation('hsigmoid')
        self.mul = P.Mul()

    def construct(self, x):
        out = self.pool(x)
        out = self.conv1(out)
        out = self.act1(out)
        out = self.conv2(out)
        out = self.act2(out)
        out = self.mul(x, out)
        return out
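The SE block is a squeeze-and-excitation gate: global-average-pool to 1x1, bottleneck through two 1x1 convs, then rescale the input channel-wise. A shape sketch under an assumed NCHW input:

    # x: (N, 96, 14, 14)
    # pool  -> (N, 96, 1, 1)     keep_dims=True preserves the conv layout
    # conv1 -> (N, 24, 1, 1)     num_mid = _make_divisible(96 // 4) = 24
    # conv2 -> (N, 96, 1, 1)     hsigmoid squashes the gates into [0, 1]
    # mul   -> (N, 96, 14, 14)   broadcast multiply against x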
class Unit(nn.Cell):
    """
    Unit wrapper definition (Conv2d + BatchNorm + optional activation).

    Args:
        num_in (int): Input channel.
        num_out (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size.
        padding (int): Padding number.
        num_groups (int): Output num group.
        use_act (bool): Used activation or not.
        act_type (string): Activation type.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> Unit(3, 3)
    """

    def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, num_groups=1,
                 use_act=True, act_type='relu'):
        super(Unit, self).__init__()
        self.conv = nn.Conv2d(in_channels=num_in,
                              out_channels=num_out,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              group=num_groups,
                              has_bias=False,
                              pad_mode='pad')
        self.bn = nn.BatchNorm2d(num_out)
        self.use_act = use_act
        self.act = Activation(act_type) if use_act else None

    def construct(self, x):
        out = self.conv(x)
        out = self.bn(out)
        if self.use_act:
            out = self.act(out)
        return out


class ResUnit(nn.Cell):
    """
    ResUnit wrapper definition.

    Args:
        num_in (int): Input channel.
        num_mid (int): Middle channel.
        num_out (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size.
        act_type (str): Activation type.
        use_se (bool): Use SE wrapper or not.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResUnit(16, 3, 1, 1)
    """

    def __init__(self, num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False):
        super(ResUnit, self).__init__()
        self.use_se = use_se
        self.first_conv = (num_out != num_mid)
        self.use_short_cut_conv = True

        if self.first_conv:
            self.expand = Unit(num_in, num_mid, kernel_size=1,
                               stride=1, padding=0, act_type=act_type)
        else:
            self.expand = None
        # depthwise conv: one group per channel
        self.conv1 = Unit(num_mid, num_mid, kernel_size=kernel_size, stride=stride,
                          padding=self._get_pad(kernel_size), act_type=act_type, num_groups=num_mid)
        if use_se:
            self.se = SE(num_mid)
        # linear pointwise projection (no activation)
        self.conv2 = Unit(num_mid, num_out, kernel_size=1, stride=1,
                          padding=0, act_type=act_type, use_act=False)

        if num_in != num_out or stride != 1:
            self.use_short_cut_conv = False
        self.add = P.TensorAdd() if self.use_short_cut_conv else None

    def construct(self, x):
        if self.first_conv:
            out = self.expand(x)
        else:
            out = x
        out = self.conv1(out)
        if self.use_se:
            out = self.se(out)
        out = self.conv2(out)
        if self.use_short_cut_conv:
            out = self.add(x, out)
        return out

    def _get_pad(self, kernel_size):
        """set the padding number"""
        pad = 0
        if kernel_size == 1:
            pad = 0
        elif kernel_size == 3:
            pad = 1
        elif kernel_size == 5:
            pad = 2
        elif kernel_size == 7:
            pad = 3
        else:
            raise NotImplementedError
        return pad
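_get_pad just encodes same-padding for the odd kernel sizes used by the cfg tables, i.e. pad = (kernel_size - 1) // 2:

    # k=1 -> 0, k=3 -> 1, k=5 -> 2, k=7 -> 3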
class MobileNetV3(nn.Cell):
    """
    MobileNetV3 architecture.

    Args:
        model_cfgs (dict): Model configuration (bottleneck cfg table plus head channel sizes).
        num_classes (int): Output number of classes.
        multiplier (float): Channel width multiplier, rounded to multiples of 8. Default is 1.
        final_drop (float): Dropout probability for the head. Default is 0.
        round_nearest (int): Channel rounding divisor. Default is 8.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> MobileNetV3(num_classes=1000)
    """

    def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8):
        super(MobileNetV3, self).__init__()
        self.cfgs = model_cfgs['cfg']
        self.inplanes = 16
        self.features = []
        first_conv_in_channel = 3
        first_conv_out_channel = _make_divisible(multiplier * self.inplanes)

        # stem: 3x3 stride-2 conv + BN + h-swish
        self.features.append(nn.Conv2d(in_channels=first_conv_in_channel,
                                       out_channels=first_conv_out_channel,
                                       kernel_size=3, padding=1, stride=2,
                                       has_bias=False, pad_mode='pad'))
        self.features.append(nn.BatchNorm2d(first_conv_out_channel))
        self.features.append(Activation('hswish'))

        # bottleneck stages from the cfg table
        for layer_cfg in self.cfgs:
            self.features.append(self._make_layer(kernel_size=layer_cfg[0],
                                                  exp_ch=_make_divisible(multiplier * layer_cfg[1]),
                                                  out_channel=_make_divisible(multiplier * layer_cfg[2]),
                                                  use_se=layer_cfg[3],
                                                  act_func=layer_cfg[4],
                                                  stride=layer_cfg[5]))

        # head: 1x1 squeeze conv, global pool, 1x1 expand conv
        output_channel = _make_divisible(multiplier * model_cfgs["cls_ch_squeeze"])
        self.features.append(nn.Conv2d(in_channels=_make_divisible(multiplier * self.cfgs[-1][2]),
                                       out_channels=output_channel,
                                       kernel_size=1, padding=0, stride=1,
                                       has_bias=False, pad_mode='pad'))
        self.features.append(nn.BatchNorm2d(output_channel))
        self.features.append(Activation('hswish'))
        self.features.append(GlobalAvgPooling(keep_dims=True))
        self.features.append(nn.Conv2d(in_channels=output_channel,
                                       out_channels=model_cfgs['cls_ch_expand'],
                                       kernel_size=1, padding=0, stride=1,
                                       has_bias=False, pad_mode='pad'))
        self.features.append(Activation('hswish'))
        if final_drop > 0:
            self.features.append(nn.Dropout(final_drop))

        # make it nn.CellList
        self.features = nn.SequentialCell(self.features)
        self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'],
                                out_channels=num_classes,
                                kernel_size=1, has_bias=True, pad_mode='pad')
        self.squeeze = P.Squeeze(axis=(2, 3))

        self._initialize_weights()

    def construct(self, x):
        x = self.features(x)
        x = self.output(x)
        x = self.squeeze(x)
        return x

    def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1):
        mid_planes = exp_ch
        out_planes = out_channel
        # ResUnit signature: (num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False)
        layer = ResUnit(self.inplanes, mid_planes, out_planes,
                        kernel_size, stride=stride, act_type=act_func, use_se=use_se)
        self.inplanes = out_planes
        return layer

    def _initialize_weights(self):
        """
        Initialize weights.

        Returns:
            None.

        Examples:
            >>> _initialize_weights()
        """
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                                    m.weight.data.shape()).astype("float32")))
                if m.bias is not None:
                    m.bias.set_parameter_data(
                        Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_parameter_data(
                    Tensor(np.ones(m.gamma.data.shape(), dtype="float32")))
                m.beta.set_parameter_data(
                    Tensor(np.zeros(m.beta.data.shape(), dtype="float32")))
            elif isinstance(m, nn.Dense):
                m.weight.set_parameter_data(Tensor(np.random.normal(
                    0, 0.01, m.weight.data.shape()).astype("float32")))
                if m.bias is not None:
                    m.bias.set_parameter_data(
                        Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
def mobilenet_v3(model_name, **kwargs):
    """
    Constructs a MobileNetV3 model ('large' or 'small').
    """
    model_cfgs = {
        "large": {
            "cfg": [
                # k, exp, c, se, nl, s,
                [3, 16, 16, False, 'relu', 1],
                [3, 64, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', 2],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hswish', 2],
                [3, 200, 80, False, 'hswish', 1],
                [3, 184, 80, False, 'hswish', 1],
                [3, 184, 80, False, 'hswish', 1],
                [3, 480, 112, True, 'hswish', 1],
                [3, 672, 112, True, 'hswish', 1],
                [5, 672, 160, True, 'hswish', 2],
                [5, 960, 160, True, 'hswish', 1],
                [5, 960, 160, True, 'hswish', 1]],
            "cls_ch_squeeze": 960,
            "cls_ch_expand": 1280,
        },
        "small": {
            "cfg": [
                # k, exp, c, se, nl, s,
                [3, 16, 16, True, 'relu', 2],
                [3, 72, 24, False, 'relu', 2],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hswish', 2],
                [5, 240, 40, True, 'hswish', 1],
                [5, 240, 40, True, 'hswish', 1],
                [5, 120, 48, True, 'hswish', 1],
                [5, 144, 48, True, 'hswish', 1],
                [5, 288, 96, True, 'hswish', 2],
                [5, 576, 96, True, 'hswish', 1],
                [5, 576, 96, True, 'hswish', 1]],
            "cls_ch_squeeze": 576,
            "cls_ch_expand": 1280,
        }
    }
    return MobileNetV3(model_cfgs[model_name], **kwargs)


mobilenet_v3_large = partial(mobilenet_v3, model_name="large")
mobilenet_v3_small = partial(mobilenet_v3, model_name="small")
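Typical construction goes through the partials; extra kwargs are forwarded straight into MobileNetV3.__init__ (a sketch):

    net   = mobilenet_v3_large(num_classes=1000)
    small = mobilenet_v3_small(num_classes=1000, multiplier=1.0)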
@@ -205,7 +205,7 @@ if __name__ == '__main__':
         config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
                                      keep_checkpoint_max=config_gpu.keep_checkpoint_max)
         ckpt_cb = ModelCheckpoint(
-            prefix="mobilenet", directory=config_gpu.save_checkpoint_path, config=config_ck)
+            prefix="mobilenetV3", directory=config_gpu.save_checkpoint_path, config=config_ck)
         cb += [ckpt_cb]
         # begine train
         model.train(epoch_size, dataset, callbacks=cb)
@@ -265,7 +265,7 @@ if __name__ == '__main__':
         config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size,
                                      keep_checkpoint_max=config_ascend.keep_checkpoint_max)
         ckpt_cb = ModelCheckpoint(
-            prefix="mobilenet", directory=config_ascend.save_checkpoint_path, config=config_ck)
+            prefix="mobilenetV3", directory=config_ascend.save_checkpoint_path, config=config_ck)
         cb += [ckpt_cb]
         model.train(epoch_size, dataset, callbacks=cb)
     else: