@@ -24,7 +24,8 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.initializer import initializer
-from .mobilenet import InvertedResidual, ConvBNReLU
+from mindspore.ops.operations import TensorAdd
+from mindspore import Parameter


 def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
@@ -45,6 +46,129 @@ def _make_divisible(v, divisor, min_value=None):
     return new_v
+
+
+class DepthwiseConv(nn.Cell):
+    """
+    Depthwise Convolution wrapper definition.
+
+    Args:
+        in_planes (int): Input channel.
+        kernel_size (int): Input kernel size.
+        stride (int): Stride size.
+        pad_mode (str): Pad mode, one of 'pad', 'same', 'valid'.
+        pad (int): Padding size used when pad_mode is 'pad'.
+        channel_multiplier (int): Output channel multiplier. Default: 1.
+        has_bias (bool): Whether to add a bias. Default: False.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
+    """
+    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
+        super(DepthwiseConv, self).__init__()
+        self.has_bias = has_bias
+        self.in_channels = in_planes
+        self.channel_multiplier = channel_multiplier
+        self.out_channels = in_planes * channel_multiplier
+        self.kernel_size = (kernel_size, kernel_size)
+        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
+                                                      kernel_size=self.kernel_size,
+                                                      stride=stride, pad_mode=pad_mode, pad=pad)
+        self.bias_add = P.BiasAdd()
+        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
+        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
+        if has_bias:
+            bias_shape = [channel_multiplier * in_planes]
+            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
+        else:
+            self.bias = None
+
+    def construct(self, x):
+        output = self.depthwise_conv(x, self.weight)
+        if self.has_bias:
+            output = self.bias_add(output, self.bias)
+        return output
+
+
+class ConvBNReLU(nn.Cell):
+    """
+    Convolution/Depthwise fused with Batchnorm and ReLU block definition.
+
+    Args:
+        in_planes (int): Input channel.
+        out_planes (int): Output channel.
+        kernel_size (int): Input kernel size.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        groups (int): Channel group. 1 for standard convolution, equal to the input channel for depthwise convolution. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
+    """
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        super(ConvBNReLU, self).__init__()
+        padding = (kernel_size - 1) // 2
+        if groups == 1:
+            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
+                             padding=padding)
+        else:
+            conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
+        layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
+        self.features = nn.SequentialCell(layers)
+
+    def construct(self, x):
+        output = self.features(x)
+        return output
+
+
+class InvertedResidual(nn.Cell):
+    """
+    MobileNetV2 inverted residual block definition.
+
+    Args:
+        inp (int): Input channel.
+        oup (int): Output channel.
+        stride (int): Stride size of the depthwise convolution. Must be 1 or 2.
+        expand_ratio (int): Expansion ratio of the input channel.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> InvertedResidual(3, 256, 1, 1)
+    """
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        assert stride in [1, 2]
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = stride == 1 and inp == oup
+        layers = []
+        if expand_ratio != 1:
+            # pw: expand channels with a 1x1 convolution
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        layers.extend([
+            # dw: depthwise convolution on the expanded features
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            # pw-linear: project back without a non-linearity
+            nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
+            nn.BatchNorm2d(oup),
+        ])
+        self.conv = nn.SequentialCell(layers)
+        self.add = TensorAdd()
+        self.cast = P.Cast()
+
+    def construct(self, x):
+        identity = x
+        x = self.conv(x)
+        if self.use_res_connect:
+            return self.add(identity, x)
+        return x


 class FlattenConcat(nn.Cell):
     """
     Concatenate predictions into a single tensor.
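The three cells added above follow the usual MobileNetV2 pattern: a 1x1 ConvBNReLU expands the channels, a depthwise ConvBNReLU filters them, and a linear 1x1 convolution projects them back, with a TensorAdd residual only when the stride is 1 and the channel counts match. A minimal usage sketch, assuming the classes can be imported from this file (the `src.ssd` module path is an assumption) and that PyNative mode is available on the target device:

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor

from src.ssd import InvertedResidual  # assumed module path for the file in this diff

context.set_context(mode=context.PYNATIVE_MODE)

# Dummy NCHW feature map: batch 1, 16 channels, 32x32 spatial (illustrative sizes).
x = Tensor(np.ones((1, 16, 32, 32), np.float32))

# Stride 1 with matching channels -> the TensorAdd residual path is taken.
block = InvertedResidual(inp=16, oup=16, stride=1, expand_ratio=6)
same = block(x)   # expected shape: (1, 16, 32, 32)

# Stride 2 or a channel change -> plain feed-forward path, no residual.
down = InvertedResidual(inp=16, oup=24, stride=2, expand_ratio=6)
half = down(x)    # expected shape: (1, 24, 16, 16)
```

The `use_res_connect` flag in the block is exactly the `stride == 1 and inp == oup` condition exercised by these two calls.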
@@ -57,20 +181,17 @@ class FlattenConcat(nn.Cell):
     """
     def __init__(self, config):
         super(FlattenConcat, self).__init__()
-        self.sizes = config.FEATURE_SIZE
-        self.length = len(self.sizes)
-        self.num_default = config.NUM_DEFAULT
-        self.concat = P.Concat(axis=-1)
+        self.num_ssd_boxes = config.NUM_SSD_BOXES
+        self.concat = P.Concat(axis=1)
         self.transpose = P.Transpose()

-    def construct(self, x):
+    def construct(self, inputs):
         output = ()
-        for i in range(self.length):
-            shape = F.shape(x[i])
-            mid_shape = (shape[0], -1, self.num_default[i], self.sizes[i], self.sizes[i])
-            final_shape = (shape[0], -1, self.num_default[i] * self.sizes[i] * self.sizes[i])
-            output += (F.reshape(F.reshape(x[i], mid_shape), final_shape),)
+        batch_size = F.shape(inputs[0])[0]
+        for x in inputs:
+            x = self.transpose(x, (0, 2, 3, 1))
+            output += (F.reshape(x, (batch_size, -1)),)
         res = self.concat(output)
-        return self.transpose(res, (0, 2, 1))
+        return F.reshape(res, (batch_size, self.num_ssd_boxes, -1))


 class MultiBox(nn.Cell):
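The rewritten FlattenConcat no longer depends on the per-layer FEATURE_SIZE and NUM_DEFAULT tables: each NCHW prediction map is transposed to NHWC, flattened to (batch, -1), concatenated along axis 1, and reshaped to (batch, NUM_SSD_BOXES, -1). A small numpy stand-in (illustrative sizes, not the real SSD feature maps) to check the shape arithmetic:

```python
import numpy as np

# Stand-in for two SSD prediction maps in NCHW, with 3 default boxes per cell
# and 4 box coordinates per prediction (channels = 3 * 4 = 12).
feature_maps = [np.zeros((8, 12, 19, 19), np.float32),
                np.zeros((8, 12, 10, 10), np.float32)]

batch = feature_maps[0].shape[0]
flat = [np.transpose(f, (0, 2, 3, 1)).reshape(batch, -1) for f in feature_maps]
merged = np.concatenate(flat, axis=1)

num_boxes = 3 * (19 * 19 + 10 * 10)        # 1383 default boxes in total
out = merged.reshape(batch, num_boxes, -1)  # (8, 1383, 4)
print(out.shape)
```

The transpose to NHWC before flattening is what makes consecutive rows of the final (batch, num_boxes, -1) tensor correspond to the default boxes of one spatial cell.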
@@ -145,7 +266,6 @@ class SSD300(nn.Cell):
         if not is_training:
             self.softmax = P.Softmax()

     def construct(self, x):
         layer_out_13, output = self.backbone(x)
         multi_feature = (layer_out_13, output)
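SSD300.construct expects the backbone to return a pair of feature maps, an intermediate `layer_out_13` plus the final `output`, which are then gathered into `multi_feature` for the multi-box head. A hypothetical stand-in backbone, only to illustrate that two-output contract (TinyBackbone and its channel counts are invented and are not the real MobileNetV2 backbone):

```python
import mindspore.nn as nn

from src.ssd import ConvBNReLU  # assumed module path for the file in this diff


class TinyBackbone(nn.Cell):
    """Hypothetical backbone: returns (intermediate, final) feature maps."""
    def __init__(self):
        super(TinyBackbone, self).__init__()
        self.stem = ConvBNReLU(3, 32, stride=2)
        self.tail = ConvBNReLU(32, 96, stride=2)

    def construct(self, x):
        layer_out_13 = self.stem(x)        # intermediate feature map
        output = self.tail(layer_out_13)   # final feature map
        return layer_out_13, output
```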