From: @yuchaojie Reviewed-by: @kingxian,@c_34 Signed-off-by: @kingxian tags/v1.2.0-rc1
| @@ -20,22 +20,22 @@ Previously the kernel size and pad mode attrs of pooling ops are named "ksize" a | |||
| <td> | |||
| ```python | |||
| >>> from mindspore.ops import operations as P | |||
| >>> import mindspore.ops as ops | |||
| >>> | |||
| >>> avg_pool = P.AvgPool(ksize=2, padding='same') | |||
| >>> max_pool = P.MaxPool(ksize=2, padding='same') | |||
| >>> max_pool_with_argmax = P.MaxPoolWithArgmax(ksize=2, padding='same') | |||
| >>> avg_pool = ops.AvgPool(ksize=2, padding='same') | |||
| >>> max_pool = ops.MaxPool(ksize=2, padding='same') | |||
| >>> max_pool_with_argmax = ops.MaxPoolWithArgmax(ksize=2, padding='same') | |||
| ``` | |||
| </td> | |||
| <td> | |||
| ```python | |||
| >>> from mindspore.ops import operations as P | |||
| >>> import mindspore.ops as ops | |||
| >>> | |||
| >>> avg_pool = P.AvgPool(kernel_size=2, pad_mode='same') | |||
| >>> max_pool = P.MaxPool(kernel_size=2, pad_mode='same') | |||
| >>> max_pool_with_argmax = P.MaxPoolWithArgmax(kernel_size=2, pad_mode='same') | |||
| >>> avg_pool = ops.AvgPool(kernel_size=2, pad_mode='same') | |||
| >>> max_pool = ops.MaxPool(kernel_size=2, pad_mode='same') | |||
| >>> max_pool_with_argmax = ops.MaxPoolWithArgmax(kernel_size=2, pad_mode='same') | |||
| ``` | |||
| </td> | |||
| @@ -18,6 +18,7 @@ | |||
| import math | |||
| import operator | |||
| from functools import reduce, partial | |||
| from mindspore import log as logger | |||
| from mindspore._checkparam import _check_3d_int_or_tuple | |||
| import numpy as np | |||
| from ... import context | |||
| @@ -1476,6 +1477,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): | |||
| dilation=1, | |||
| group=1): | |||
| """Initialize DepthwiseConv2dNative""" | |||
| logger.warning("WARN_DEPRECATED: The usage of DepthwiseConv2dNative is deprecated." | |||
| " Please use nn.Conv2D.") | |||
| self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) | |||
| self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) | |||
| self.stride = _check_positive_int_or_tuple('stride', stride, self.name) | |||
| @@ -102,7 +102,7 @@ step1: prepare pretrained model: train a mobilenet_v2 model by mindspore or use | |||
| # The key/cell/module names must be as follows; otherwise you need to modify the "name_map" function: | |||
| # --mindspore: as the same as mobilenet_v2_key.ckpt | |||
| # --pytorch: same as official pytorch model(e.g., official mobilenet_v2-b0353104.pth) | |||
| python torch_to_ms_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt | |||
| python convert_weight_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt | |||
| ``` | |||
| step2: prepare user rank_table | |||
| @@ -120,7 +120,7 @@ step3: train | |||
| cd scripts; | |||
| # prepare data_path, use symbolic link | |||
| ln -sf [USE_DATA_DIR] dataset | |||
| # check you dir to make sure your datas are in the right path | |||
| # check your dir to make sure your data are in the right path | |||
| ls ./dataset/centerface # data path | |||
| ls ./dataset/centerface/annotations/train.json # annot_path | |||
| ls ./dataset/centerface/images/train/images # img_dir | |||
| @@ -147,7 +147,7 @@ python setup.py install; # used for eval | |||
| cd -; #cd ../../scripts; | |||
| mkdir ./output | |||
| mkdir ./output/centerface | |||
| # check you dir to make sure your datas are in the right path | |||
| # check your dir to make sure your data are in the right path | |||
| ls ./dataset/images/val/images/ # data path | |||
| ls ./dataset/centerface/ground_truth/val.mat # annot_path | |||
| ``` | |||
| @@ -195,7 +195,7 @@ sh eval_all.sh | |||
| │ ├──lr_scheduler.py // learning rate scheduler | |||
| │ ├──mobile_v2.py // modified mobilenet_v2 backbone | |||
| │ ├──utils.py // auxiliary functions for train, to log and preload | |||
| │ ├──var_init.py // weight initilization | |||
| │ ├──var_init.py // weight initialization | |||
| │ ├──convert_weight_mobilenetv2.py // convert pretrained backbone to mindspore | |||
| │ ├──convert_weight.py // CenterFace model convert to mindspore | |||
| └── dependency // third party codes: MIT License | |||
| @@ -414,7 +414,7 @@ After testing, you can find many txt file save the box information and scores, | |||
| open it you can see: | |||
| ```python | |||
| 646.3 189.1 42.1 51.8 0.747 # left top hight weight score | |||
| 646.3 189.1 42.1 51.8 0.747 # left top height weight score | |||
| 157.4 408.6 43.1 54.1 0.667 | |||
| 120.3 212.4 38.7 42.8 0.650 | |||
| ... | |||
| @@ -553,7 +553,7 @@ CenterFace on 3.2K images(The annotation and data format must be the same as wid | |||
| # [Description of Random Situation](#contents) | |||
| In dataset.py, we set the seed inside ```create_dataset``` function. | |||
| In var_init.py, we set seed for weight initilization | |||
| In var_init.py, we set seed for weight initialization | |||
| # [ModelZoo Homepage](#contents) | |||
| @@ -133,11 +133,6 @@ def pt_to_ckpt(pt, ckpt, out_path): | |||
| parameter = state_dict_torch[key] | |||
| parameter = parameter.numpy() | |||
| # depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout | |||
| if state_dict_ms[name_relate[key]].data.shape != parameter.shape: | |||
| parameter = parameter.transpose(1, 0, 2, 3) | |||
| print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key) | |||
| param_dict['name'] = name_relate[key] | |||
| param_dict['data'] = Tensor(parameter) | |||
| new_params_list.append(param_dict) | |||
| @@ -158,13 +153,6 @@ def ckpt_to_pt(pt, ckpt, out_path): | |||
| name = name_relate[key] | |||
| parameter = state_dict_ms[name].data | |||
| parameter = parameter.asnumpy() | |||
| if state_dict_ms[name_relate[key]].data.shape != state_dict_torch[key].numpy().shape: | |||
| print('before ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', | |||
| state_dict_torch[key].numpy().shape, 'name=', key) | |||
| parameter = parameter.transpose(1, 0, 2, 3) | |||
| print('after ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', | |||
| state_dict_torch[key].numpy().shape, 'name=', key) | |||
| state_dict[key] = torch.from_numpy(parameter) | |||
| save_model(out_path, epoch=0, model=None, optimizer=None, state_dict=state_dict) | |||
| @@ -120,12 +120,6 @@ def pt_to_ckpt(pt, ckpt, out_ckpt): | |||
| parameter = state_dict_torch[key] | |||
| parameter = parameter.numpy() | |||
| # depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout | |||
| if state_dict_ms[name_relate[key]].data.shape != parameter.shape: | |||
| parameter = parameter.transpose(1, 0, 2, 3) | |||
| print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key) | |||
| param_dict['name'] = name_relate[key] | |||
| param_dict['data'] = Tensor(parameter) | |||
| new_params_list.append(param_dict) | |||
| @@ -17,12 +17,10 @@ | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import TensorAdd | |||
| from mindspore import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from src.var_init import KaimingNormal | |||
| __all__ = ['MobileNetV2', 'mobilenet_v2', 'DepthWiseConv'] | |||
| __all__ = ['MobileNetV2', 'mobilenet_v2'] | |||
| def _make_divisible(v, divisor, min_value=None): | |||
| """ | |||
| @@ -43,32 +41,6 @@ def _make_divisible(v, divisor, min_value=None): | |||
| new_v += divisor | |||
| return new_v | |||
| class DepthWiseConv(nn.Cell): | |||
| """ | |||
| Depthwise convolution | |||
| """ | |||
| def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): | |||
| super(DepthWiseConv, self).__init__() | |||
| self.has_bias = has_bias | |||
| self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size, | |||
| stride=stride, pad_mode=pad_mode, pad=pad) | |||
| self.bias_add = P.BiasAdd() | |||
| weight_shape = [channel_multiplier, in_planes, kernel_size, kernel_size] | |||
| self.weight = Parameter(initializer(KaimingNormal(mode='fan_out'), weight_shape)) | |||
| if has_bias: | |||
| bias_shape = [channel_multiplier * in_planes] | |||
| self.bias = Parameter(initializer('zeros', bias_shape)) | |||
| else: | |||
| self.bias = None | |||
| def construct(self, x): | |||
| output = self.depthwise_conv(x, self.weight) | |||
| if self.has_bias: | |||
| output = self.bias_add(output, self.bias) | |||
| return output | |||
| class ConvBNReLU(nn.Cell): | |||
| """ | |||
| @@ -81,7 +53,8 @@ class ConvBNReLU(nn.Cell): | |||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding, | |||
| has_bias=False) | |||
| else: | |||
| conv = DepthWiseConv(in_planes, kernel_size, stride, pad_mode="pad", pad=padding) | |||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding, | |||
| has_bias=False, group=groups, weight_init=KaimingNormal(mode='fan_out')) | |||
| layers = [conv, nn.BatchNorm2d(out_planes).add_flags_recursive(fp32=True), nn.ReLU6()] #, momentum=0.9 | |||
| self.features = nn.SequentialCell(layers) | |||
| @@ -24,8 +24,6 @@ import numpy as np | |||
| from mindspore.train.serialization import load_checkpoint | |||
| import mindspore.nn as nn | |||
| from src.mobile_v2 import DepthWiseConv | |||
| def load_backbone(net, ckpt_path, args): | |||
| """ | |||
| Load backbone | |||
| @@ -52,7 +50,7 @@ def load_backbone(net, ckpt_path, args): | |||
| for name, cell in net.cells_and_names(): | |||
| if name.startswith(centerface_backbone_prefix): | |||
| name = name.replace(centerface_backbone_prefix, mobilev2_backbone_prefix) | |||
| if isinstance(cell, (nn.Conv2d, nn.Dense, DepthWiseConv)): | |||
| if isinstance(cell, (nn.Conv2d, nn.Dense)): | |||
| name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx) | |||
| mobilev2_weight = '{}.weight'.format(name) | |||
| mobilev2_bias = '{}.bias'.format(name) | |||
| @@ -33,6 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, RunContext | |||
| from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.profiler.profiling import Profiler | |||
| from mindspore.common import set_seed | |||
| from src.utils import get_logger | |||
| from src.utils import AverageMeter | |||
| @@ -47,6 +48,7 @@ from src.config import ConfigCenterface | |||
| from src.centerface import CenterFaceWithLossCell, TrainingWrapper | |||
| from src.dataset import GetDataLoader | |||
| set_seed(1) | |||
| dev_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False, | |||
| device_target="Ascend", save_graphs=False, device_id=dev_id, reserve_class_name_in_scope=False) | |||
| @@ -130,7 +132,7 @@ if __name__ == "__main__": | |||
| args.rank = get_rank() | |||
| args.group_size = get_group_size() | |||
| # select for master rank save ckpt or all rank save, compatiable for model parallel | |||
| # select for master rank save ckpt or all rank save, compatible for model parallel | |||
| args.rank_save_ckpt_flag = 0 | |||
| if args.is_save_on_master: | |||
| if args.rank == 0: | |||
| @@ -20,10 +20,8 @@ from copy import deepcopy | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import context, ms_function | |||
| from mindspore.common.initializer import (Normal, One, Uniform, Zero, | |||
| initializer) | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore import ms_function | |||
| from mindspore.common.initializer import (Normal, One, Uniform, Zero) | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.composite import clip_by_value | |||
| @@ -224,13 +222,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0): | |||
| # activation fn | |||
| key = op[0] | |||
| v = op[1:] | |||
| if v == 're': | |||
| print('not support') | |||
| elif v == 'r6': | |||
| print('not support') | |||
| elif v == 'hs': | |||
| print('not support') | |||
| elif v == 'sw': | |||
| if v in ('re', 'r6', 'hs', 'sw'): | |||
| print('not support') | |||
| else: | |||
| continue | |||
| @@ -459,28 +451,6 @@ class BlockBuilder(nn.Cell): | |||
| return self.layer(x) | |||
| class DepthWiseConv(nn.Cell): | |||
| def __init__(self, in_planes, kernel_size, stride): | |||
| super(DepthWiseConv, self).__init__() | |||
| platform = context.get_context("device_target") | |||
| weight_shape = [1, kernel_size, in_planes] | |||
| weight_init = _initialize_weight_goog(shape=weight_shape) | |||
| if platform == "GPU": | |||
| self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1, kernel_size=kernel_size, | |||
| stride=stride, pad_mode="same", group=in_planes) | |||
| self.weight = Parameter(initializer( | |||
| weight_init, [in_planes * 1, 1, kernel_size, kernel_size])) | |||
| else: | |||
| self.depthwise_conv = P.DepthwiseConv2dNative( | |||
| channel_multiplier=1, kernel_size=kernel_size, stride=stride, pad_mode='same',) | |||
| self.weight = Parameter(initializer( | |||
| weight_init, [1, in_planes, kernel_size, kernel_size])) | |||
| def construct(self, x): | |||
| x = self.depthwise_conv(x, self.weight) | |||
| return x | |||
| class DropConnect(nn.Cell): | |||
| def __init__(self, drop_connect_rate=0., seed0=0, seed1=0): | |||
| super(DropConnect, self).__init__() | |||
| @@ -540,7 +510,9 @@ class DepthwiseSeparableConv(nn.Cell): | |||
| self.has_pw_act = pw_act | |||
| self.act_fn = act_fn | |||
| self.drop_connect_rate = drop_connect_rate | |||
| self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride) | |||
| self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="same", | |||
| has_bias=False, group=in_chs, | |||
| weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs])) | |||
| self.bn1 = _fused_bn(in_chs, **bn_args) | |||
| # | |||
| @@ -595,7 +567,9 @@ class InvertedResidual(nn.Cell): | |||
| if self.shuffle_type is not None and isinstance(exp_kernel_size, list): | |||
| self.shuffle = None | |||
| self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride) | |||
| self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="same", | |||
| has_bias=False, group=mid_chs, | |||
| weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs])) | |||
| self.bn2 = _fused_bn(mid_chs, **bn_args) | |||
| if self.has_se: | |||
| @@ -20,13 +20,12 @@ import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Parameter, context, Tensor | |||
| from mindspore import context, Tensor | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.communication.management import get_group_size | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common.initializer import initializer | |||
| def _make_divisible(x, divisor=4): | |||
| @@ -44,8 +43,8 @@ def _bn(channel): | |||
| def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0): | |||
| depthwise_conv = DepthwiseConv( | |||
| in_channel, kernel_size, stride, pad_mode='same', pad=pad) | |||
| depthwise_conv = nn.Conv2d(in_channel, in_channel, kernel_size, stride, pad_mode='same', padding=pad, | |||
| has_bias=False, group=in_channel, weight_init='ones') | |||
| conv = _conv2d(in_channel, out_channel, kernel_size=1) | |||
| return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv]) | |||
| @@ -75,8 +74,8 @@ class ConvBNReLU(nn.Cell): | |||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', | |||
| padding=padding) | |||
| else: | |||
| conv = DepthwiseConv(in_planes, kernel_size, | |||
| stride, pad_mode='same', pad=padding) | |||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', padding=padding, | |||
| has_bias=False, group=groups, weight_init='ones') | |||
| layers = [conv, _bn(out_planes)] | |||
| if use_act: | |||
| layers.append(Activation(act_type)) | |||
| @@ -87,52 +86,6 @@ class ConvBNReLU(nn.Cell): | |||
| return output | |||
| class DepthwiseConv(nn.Cell): | |||
| """ | |||
| Depthwise Convolution warpper definition. | |||
| Args: | |||
| in_planes (int): Input channel. | |||
| kernel_size (int): Input kernel size. | |||
| stride (int): Stride size. | |||
| pad_mode (str): pad mode in (pad, same, valid) | |||
| channel_multiplier (int): Output channel multiplier | |||
| has_bias (bool): has bias or not | |||
| Returns: | |||
| Tensor, output tensor. | |||
| Examples: | |||
| >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) | |||
| """ | |||
| def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): | |||
| super(DepthwiseConv, self).__init__() | |||
| self.has_bias = has_bias | |||
| self.in_channels = in_planes | |||
| self.channel_multiplier = channel_multiplier | |||
| self.out_channels = in_planes * channel_multiplier | |||
| self.kernel_size = (kernel_size, kernel_size) | |||
| self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, | |||
| kernel_size=self.kernel_size, | |||
| stride=stride, pad_mode=pad_mode, pad=pad) | |||
| self.bias_add = P.BiasAdd() | |||
| weight_shape = [channel_multiplier, in_planes, *self.kernel_size] | |||
| self.weight = Parameter(initializer('ones', weight_shape), name="weight") | |||
| if has_bias: | |||
| bias_shape = [channel_multiplier * in_planes] | |||
| self.bias = Parameter(initializer('zeros', bias_shape), name="bias") | |||
| else: | |||
| self.bias = None | |||
| def construct(self, x): | |||
| output = self.depthwise_conv(x, self.weight) | |||
| if self.has_bias: | |||
| output = self.bias_add(output, self.bias) | |||
| return output | |||
| class MyHSigmoid(nn.Cell): | |||
| def __init__(self): | |||
| super(MyHSigmoid, self).__init__() | |||
| @@ -20,9 +20,8 @@ from copy import deepcopy | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform | |||
| from mindspore import context, ms_function | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import Normal, Zero, One, Uniform | |||
| from mindspore import ms_function | |||
| from mindspore import Tensor | |||
| # Imagenet constant values | |||
| @@ -244,13 +243,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0): | |||
| # activation fn | |||
| key = op[0] | |||
| v = op[1:] | |||
| if v == 're': | |||
| print('not support') | |||
| elif v == 'r6': | |||
| print('not support') | |||
| elif v == 'hs': | |||
| print('not support') | |||
| elif v == 'sw': | |||
| if v in ('re', 'r6', 'hs', 'sw'): | |||
| print('not support') | |||
| else: | |||
| continue | |||
| @@ -485,40 +478,6 @@ class BlockBuilder(nn.Cell): | |||
| return self.layer(x) | |||
| class DepthWiseConv(nn.Cell): | |||
| """depth-wise convolution""" | |||
| def __init__(self, in_planes, kernel_size, stride): | |||
| super(DepthWiseConv, self).__init__() | |||
| platform = context.get_context("device_target") | |||
| weight_shape = [1, kernel_size, in_planes] | |||
| weight_init = _initialize_weight_goog(shape=weight_shape) | |||
| if platform == "GPU": | |||
| self.depthwise_conv = P.Conv2D(out_channel=in_planes*1, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| pad=int(kernel_size/2), | |||
| pad_mode="pad", | |||
| group=in_planes) | |||
| self.weight = Parameter(initializer(weight_init, | |||
| [in_planes*1, 1, kernel_size, kernel_size])) | |||
| else: | |||
| self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=kernel_size, | |||
| stride=stride, pad_mode='pad', | |||
| pad=int(kernel_size/2)) | |||
| self.weight = Parameter(initializer(weight_init, | |||
| [1, in_planes, kernel_size, kernel_size])) | |||
| def construct(self, x): | |||
| x = self.depthwise_conv(x, self.weight) | |||
| return x | |||
| class DropConnect(nn.Cell): | |||
| """drop connect implementation""" | |||
| @@ -584,7 +543,9 @@ class DepthwiseSeparableConv(nn.Cell): | |||
| self.has_pw_act = pw_act | |||
| self.act_fn = Swish() | |||
| self.drop_connect_rate = drop_connect_rate | |||
| self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride) | |||
| self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="pad", | |||
| padding=int(dw_kernel_size/2), has_bias=False, group=in_chs, | |||
| weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs])) | |||
| self.bn1 = _fused_bn(in_chs, **bn_args) | |||
| if self.has_se: | |||
| @@ -640,7 +601,9 @@ class InvertedResidual(nn.Cell): | |||
| if self.shuffle_type is not None and isinstance(exp_kernel_size, list): | |||
| self.shuffle = None | |||
| self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride) | |||
| self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="pad", | |||
| padding=int(dw_kernel_size/2), has_bias=False, group=mid_chs, | |||
| weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs])) | |||
| self.bn2 = _fused_bn(mid_chs, **bn_args) | |||
| if self.has_se: | |||
| @@ -16,34 +16,6 @@ | |||
| import math | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore import Parameter | |||
| class DepthWiseConv(nn.Cell): | |||
| '''Build DepthWise conv.''' | |||
| def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): | |||
| super(DepthWiseConv, self).__init__() | |||
| self.has_bias = has_bias | |||
| self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size, | |||
| stride=stride, pad_mode=pad_mode, pad=pad) | |||
| self.bias_add = P.BiasAdd() | |||
| weight_shape = [channel_multiplier, in_planes, kernel_size[0], kernel_size[1]] | |||
| self.weight = Parameter(initializer('ones', weight_shape)) | |||
| if has_bias: | |||
| bias_shape = [channel_multiplier * in_planes] | |||
| self.bias = Parameter(initializer('zeros', bias_shape)) | |||
| else: | |||
| self.bias = None | |||
| def construct(self, x): | |||
| output = self.depthwise_conv(x, self.weight) | |||
| if self.has_bias: | |||
| output = self.bias_add(output, self.bias) | |||
| return output | |||
| class DSCNN(nn.Cell): | |||
| @@ -85,8 +57,9 @@ class DSCNN(nn.Cell): | |||
| seq_cell.append(nn.BatchNorm2d(num_features=conv_feat[layer_no], momentum=0.98)) | |||
| in_channel = conv_feat[layer_no] | |||
| else: | |||
| seq_cell.append(DepthWiseConv(in_planes=in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]), | |||
| stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same', pad=0)) | |||
| seq_cell.append(nn.Conv2d(in_channel, in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]), | |||
| stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same', | |||
| has_bias=False, group=in_channel, weight_init='ones')) | |||
| seq_cell.append(nn.BatchNorm2d(num_features=in_channel, momentum=0.98)) | |||
| seq_cell.append(nn.ReLU()) | |||
| seq_cell.append(nn.Conv2d(in_channels=in_channel, out_channels=conv_feat[layer_no], kernel_size=(1, 1), | |||