@@ -0,0 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init DeepLabv3."""
from .deeplabv3 import ASPP, DeepLabV3, deeplabv3_resnet50
from .backbone import *

__all__ = [
    "ASPP", "DeepLabV3", "deeplabv3_resnet50"
]

__all__.extend(backbone.__all__)
@@ -0,0 +1,21 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init backbone."""
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
    RootBlockBeta, resnet50_dl

__all__ = [
    "Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta", "resnet50_dl"
]
@@ -0,0 +1,577 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet based DeepLab."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore._checkparam import twice
from mindspore.common.parameter import Parameter
def _conv_bn_relu(in_channel,
                  out_channel,
                  ksize,
                  stride=1,
                  padding=0,
                  dilation=1,
                  pad_mode="pad",
                  use_batch_statistics=False):
    """Get a conv2d -> batchnorm -> relu layer."""
    return nn.SequentialCell(
        [nn.Conv2d(in_channel,
                   out_channel,
                   kernel_size=ksize,
                   stride=stride,
                   padding=padding,
                   dilation=dilation,
                   pad_mode=pad_mode),
         nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics),
         nn.ReLU()]
    )
def _deep_conv_bn_relu(in_channel,
                       channel_multiplier,
                       ksize,
                       stride=1,
                       padding=0,
                       dilation=1,
                       pad_mode="pad",
                       use_batch_statistics=False):
    """Get a depthwise conv2d -> batchnorm -> relu layer."""
    return nn.SequentialCell(
        [DepthwiseConv2dNative(in_channel,
                               channel_multiplier,
                               kernel_size=ksize,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               pad_mode=pad_mode),
         nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics),
         nn.ReLU()]
    )
def _stob_deep_conv_btos_bn_relu(in_channel,
                                 channel_multiplier,
                                 ksize,
                                 space_to_batch_block_shape,
                                 batch_to_space_block_shape,
                                 paddings,
                                 crops,
                                 stride=1,
                                 padding=0,
                                 dilation=1,
                                 pad_mode="pad",
                                 use_batch_statistics=False):
    """Get a space_to_batch -> depthwise conv2d -> batch_to_space -> batchnorm -> relu layer."""
    return nn.SequentialCell(
        [SpaceToBatch(space_to_batch_block_shape, paddings),
         DepthwiseConv2dNative(in_channel,
                               channel_multiplier,
                               kernel_size=ksize,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               pad_mode=pad_mode),
         BatchToSpace(batch_to_space_block_shape, crops),
         nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics),
         nn.ReLU()]
    )
def _stob_conv_btos_bn_relu(in_channel,
                            out_channel,
                            ksize,
                            space_to_batch_block_shape,
                            batch_to_space_block_shape,
                            paddings,
                            crops,
                            stride=1,
                            padding=0,
                            dilation=1,
                            pad_mode="pad",
                            use_batch_statistics=False):
    """Get a space_to_batch -> conv2d -> batch_to_space -> batchnorm -> relu layer."""
    return nn.SequentialCell([SpaceToBatch(space_to_batch_block_shape, paddings),
                              nn.Conv2d(in_channel,
                                        out_channel,
                                        kernel_size=ksize,
                                        stride=stride,
                                        padding=padding,
                                        dilation=dilation,
                                        pad_mode=pad_mode),
                              BatchToSpace(batch_to_space_block_shape, crops),
                              nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics),
                              nn.ReLU()]
                             )
def _make_layer(block,
                in_channels,
                out_channels,
                num_blocks,
                stride=1,
                rate=1,
                multi_grads=None,
                output_stride=None,
                g_current_stride=2,
                g_rate=1):
    """Make layer for DeepLab-ResNet network."""
    if multi_grads is None:
        multi_grads = [1] * num_blocks
    # (stride == 2, num_blocks == 4 --> strides == [1, 1, 1, 2])
    strides = [1] * (num_blocks - 1) + [stride]
    blocks = []
    if output_stride is not None:
        if output_stride % 4 != 0:
            raise ValueError('The output_stride needs to be a multiple of 4.')
        output_stride //= 4
    for i_stride, _ in enumerate(strides):
        if output_stride is not None and g_current_stride > output_stride:
            raise ValueError('The target output_stride cannot be reached.')
        if output_stride is not None and g_current_stride == output_stride:
            b_rate = g_rate
            b_stride = 1
            g_rate *= strides[i_stride]
        else:
            b_rate = rate
            b_stride = strides[i_stride]
            g_current_stride *= strides[i_stride]
        blocks.append(block(in_channels=in_channels,
                            out_channels=out_channels,
                            stride=b_stride,
                            rate=b_rate,
                            multi_grad=multi_grads[i_stride]))
        in_channels = out_channels
    layer = nn.SequentialCell(blocks)
    return layer, g_current_stride, g_rate
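

# A hedged sketch of the output_stride bookkeeping above, not part of the
# original diff: once the accumulated stride `g_current_stride` reaches the
# target, blocks switch to stride 1 and fold their nominal stride into the
# atrous rate instead. Assumes the default rate of 1; the helper name is
# hypothetical and purely illustrative.
def _sketch_stride_plan(strides, output_stride, g_current_stride=2, g_rate=1):
    """Illustrative only: return the (b_stride, b_rate) pair for each block."""
    plan = []
    target = output_stride // 4
    for s in strides:
        if g_current_stride == target:
            # Target resolution reached: stop striding, dilate instead.
            plan.append((1, g_rate))
            g_rate *= s
        else:
            plan.append((s, 1))
            g_current_stride *= s
    return plan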
class Subsample(nn.Cell):
    """
    Subsample for DeepLab-ResNet.

    Args:
        factor (int): Sample factor.

    Returns:
        Tensor, the subsampled tensor.

    Examples:
        >>> Subsample(2)
    """
    def __init__(self, factor):
        super(Subsample, self).__init__()
        self.factor = factor
        self.pool = nn.MaxPool2d(kernel_size=1,
                                 stride=factor)

    def construct(self, x):
        if self.factor == 1:
            return x
        return self.pool(x)
class SpaceToBatch(nn.Cell):
    """Wrap P.SpaceToBatch with a fixed block shape and paddings."""
    def __init__(self, block_shape, paddings):
        super(SpaceToBatch, self).__init__()
        self.space_to_batch = P.SpaceToBatch(block_shape, paddings)
        self.bs = block_shape
        self.pd = paddings

    def construct(self, x):
        return self.space_to_batch(x)


class BatchToSpace(nn.Cell):
    """Wrap P.BatchToSpace with a fixed block shape and crops."""
    def __init__(self, block_shape, crops):
        super(BatchToSpace, self).__init__()
        self.batch_to_space = P.BatchToSpace(block_shape, crops)
        self.bs = block_shape
        self.cr = crops

    def construct(self, x):
        return self.batch_to_space(x)
class _DepthwiseConv2dNative(nn.Cell):
    """Depthwise Conv2D Cell."""
    def __init__(self,
                 in_channels,
                 channel_multiplier,
                 kernel_size,
                 stride,
                 pad_mode,
                 padding,
                 dilation,
                 group,
                 weight_init):
        super(_DepthwiseConv2dNative, self).__init__()
        self.in_channels = in_channels
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = dilation
        self.group = group
        if not (isinstance(in_channels, int) and in_channels > 0):
            raise ValueError('Attr \'in_channels\' of \'DepthwiseConv2D\' Op passed '
                             + str(in_channels) + ', should be an int greater than 0.')
        if (not isinstance(kernel_size, tuple)) or len(kernel_size) != 2 or \
                (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
                kernel_size[0] < 1 or kernel_size[1] < 1:
            raise ValueError('Attr \'kernel_size\' of \'DepthwiseConv2D\' Op passed '
                             + str(self.kernel_size) + ', should be an int or a tuple of two ints '
                             'equal to or greater than 1.')
        self.weight = Parameter(initializer(weight_init, [1, in_channels // group, *kernel_size]),
                                name='weight')

    def construct(self, *inputs):
        """Must be overridden by all subclasses."""
        raise NotImplementedError
class DepthwiseConv2dNative(_DepthwiseConv2dNative):
    """Depthwise Conv2D Cell."""
    def __init__(self,
                 in_channels,
                 channel_multiplier,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 weight_init='normal'):
        kernel_size = twice(kernel_size)
        super(DepthwiseConv2dNative, self).__init__(
            in_channels,
            channel_multiplier,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            weight_init)
        self.depthwise_conv2d_native = P.DepthwiseConv2dNative(channel_multiplier=self.channel_multiplier,
                                                               kernel_size=self.kernel_size,
                                                               mode=3,
                                                               pad_mode=self.pad_mode,
                                                               pad=self.padding,
                                                               stride=self.stride,
                                                               dilation=self.dilation,
                                                               group=self.group)

    def set_strategy(self, strategy):
        self.depthwise_conv2d_native.set_strategy(strategy)
        return self

    def construct(self, x):
        return self.depthwise_conv2d_native(x, self.weight)
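

# A hedged usage sketch, not part of the original diff; assumes a configured
# MindSpore context on a backend that supports P.DepthwiseConv2dNative. The
# function name is hypothetical.
def _example_depthwise_forward():
    """Illustrative only: 'same' padding keeps the 16x16 spatial size."""
    import numpy as np
    from mindspore import Tensor
    dw = DepthwiseConv2dNative(in_channels=8, channel_multiplier=1, kernel_size=3)
    # Output shape is (1, 8 * channel_multiplier, 16, 16).
    return dw(Tensor(np.ones((1, 8, 16, 16), np.float32)))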
class BottleneckV1(nn.Cell):
    """
    ResNet V1 BottleneckV1 block definition.

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        stride (int): Stride size for the 3x3 convolutional layer. Default: 1.
        use_batch_statistics (bool): Use mini-batch statistics in batch norm layers. Default: False.
        use_batch_to_stob_and_btos (bool): Wrap the 3x3 convolution with
            space-to-batch and batch-to-space for atrous computation. Default: False.

    Returns:
        Tensor, the ResNet unit's output.

    Examples:
        >>> BottleneckV1(3, 256, stride=2)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 use_batch_statistics=False,
                 use_batch_to_stob_and_btos=False):
        super(BottleneckV1, self).__init__()
        expansion = 4
        mid_channels = out_channels // expansion
        self.conv_bn1 = _conv_bn_relu(in_channels,
                                      mid_channels,
                                      ksize=1,
                                      stride=1,
                                      use_batch_statistics=use_batch_statistics)
        self.conv_bn2 = _conv_bn_relu(mid_channels,
                                      mid_channels,
                                      ksize=3,
                                      stride=stride,
                                      padding=1,
                                      dilation=1,
                                      use_batch_statistics=use_batch_statistics)
        if use_batch_to_stob_and_btos:
            self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels,
                                                    mid_channels,
                                                    ksize=3,
                                                    stride=stride,
                                                    padding=0,
                                                    dilation=1,
                                                    space_to_batch_block_shape=2,
                                                    batch_to_space_block_shape=2,
                                                    paddings=[[2, 3], [2, 3]],
                                                    crops=[[0, 1], [0, 1]],
                                                    pad_mode="valid",
                                                    use_batch_statistics=use_batch_statistics)
        self.conv3 = nn.Conv2d(mid_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
        if in_channels != out_channels:
            conv = nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=1,
                             stride=stride)
            bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
            self.downsample = nn.SequentialCell([conv, bn])
        else:
            self.downsample = Subsample(stride)
        self.add = P.TensorAdd()
        self.relu = nn.ReLU()
        self.Reshape = P.Reshape()

    def construct(self, x):
        out = self.conv_bn1(x)
        out = self.conv_bn2(out)
        out = self.bn3(self.conv3(out))
        out = self.add(out, self.downsample(x))
        out = self.relu(out)
        return out
class BottleneckV2(nn.Cell):
    """
    ResNet V1 bottleneck block, variant V2 (identity shortcut).

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        stride (int): Stride size for the 3x3 convolutional layer. Default: 1.
        use_batch_statistics (bool): Use mini-batch statistics in batch norm layers. Default: False.
        use_batch_to_stob_and_btos (bool): Wrap the 3x3 convolution with
            space-to-batch and batch-to-space for atrous computation. Default: False.
        dilation (int): Dilation rate for the 3x3 convolutional layer. Default: 1.

    Returns:
        Tensor, the ResNet unit's output.

    Examples:
        >>> BottleneckV2(3, 256, stride=2)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 use_batch_statistics=False,
                 use_batch_to_stob_and_btos=False,
                 dilation=1):
        super(BottleneckV2, self).__init__()
        expansion = 4
        mid_channels = out_channels // expansion
        self.conv_bn1 = _conv_bn_relu(in_channels,
                                      mid_channels,
                                      ksize=1,
                                      stride=1,
                                      use_batch_statistics=use_batch_statistics)
        self.conv_bn2 = _conv_bn_relu(mid_channels,
                                      mid_channels,
                                      ksize=3,
                                      stride=stride,
                                      padding=1,
                                      dilation=dilation,
                                      use_batch_statistics=use_batch_statistics)
        if use_batch_to_stob_and_btos:
            self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels,
                                                    mid_channels,
                                                    ksize=3,
                                                    stride=stride,
                                                    padding=0,
                                                    dilation=1,
                                                    space_to_batch_block_shape=2,
                                                    batch_to_space_block_shape=2,
                                                    paddings=[[2, 3], [2, 3]],
                                                    crops=[[0, 1], [0, 1]],
                                                    pad_mode="valid",
                                                    use_batch_statistics=use_batch_statistics)
        self.conv3 = nn.Conv2d(mid_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
        if in_channels != out_channels:
            conv = nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=1,
                             stride=stride)
            bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
            self.downsample = nn.SequentialCell([conv, bn])
        else:
            self.downsample = Subsample(stride)
        self.add = P.TensorAdd()
        self.relu = nn.ReLU()

    def construct(self, x):
        out = self.conv_bn1(x)
        out = self.conv_bn2(out)
        out = self.bn3(self.conv3(out))
        # Identity shortcut: this variant adds x directly, so the downsample
        # branch built above is not used in the forward pass.
        out = self.add(out, x)
        out = self.relu(out)
        return out
class BottleneckV3(nn.Cell):
    """
    ResNet V1 bottleneck block, variant V3.

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        stride (int): Stride size for the 3x3 convolutional layer. Default: 1.
        use_batch_statistics (bool): Use mini-batch statistics in batch norm layers. Default: False.

    Returns:
        Tensor, the ResNet unit's output.

    Examples:
        >>> BottleneckV3(3, 256, stride=2)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 use_batch_statistics=False):
        super(BottleneckV3, self).__init__()
        expansion = 4
        mid_channels = out_channels // expansion
        self.conv_bn1 = _conv_bn_relu(in_channels,
                                      mid_channels,
                                      ksize=1,
                                      stride=1,
                                      use_batch_statistics=use_batch_statistics)
        self.conv_bn2 = _conv_bn_relu(mid_channels,
                                      mid_channels,
                                      ksize=3,
                                      stride=stride,
                                      padding=1,
                                      dilation=1,
                                      use_batch_statistics=use_batch_statistics)
        self.conv3 = nn.Conv2d(mid_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
        if in_channels != out_channels:
            conv = nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=1,
                             stride=stride)
            bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
            self.downsample = nn.SequentialCell([conv, bn])
        else:
            self.downsample = Subsample(stride)
        self.add = P.TensorAdd()
        self.relu = nn.ReLU()

    def construct(self, x):
        out = self.conv_bn1(x)
        out = self.conv_bn2(out)
        out = self.bn3(self.conv3(out))
        out = self.add(out, self.downsample(x))
        out = self.relu(out)
        return out
class ResNetV1(nn.Cell):
    """
    ResNet V1 for DeepLab.

    Args:
        fine_tune_batch_norm (bool): Use mini-batch statistics in batch norm layers. Default: False.

    Returns:
        Tuple, output tensor tuple, (c2, c5).

    Examples:
        >>> ResNetV1(False)
    """
    def __init__(self, fine_tune_batch_norm=False):
        super(ResNetV1, self).__init__()
        self.layer_root = nn.SequentialCell(
            [RootBlockBeta(fine_tune_batch_norm),
             nn.MaxPool2d(kernel_size=(3, 3),
                          stride=(2, 2),
                          pad_mode='same')])
        self.layer1_1 = BottleneckV1(128, 256, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer1_2 = BottleneckV2(256, 256, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer1_3 = BottleneckV3(256, 256, stride=2, use_batch_statistics=fine_tune_batch_norm)
        self.layer2_1 = BottleneckV1(256, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer2_2 = BottleneckV2(512, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer2_3 = BottleneckV2(512, 512, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer2_4 = BottleneckV3(512, 512, stride=2, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_1 = BottleneckV1(512, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_2 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_3 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_4 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_5 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer3_6 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
        self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True,
                                     use_batch_statistics=fine_tune_batch_norm)
        self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True,
                                     use_batch_statistics=fine_tune_batch_norm)
        self.layer4_3 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True,
                                     use_batch_statistics=fine_tune_batch_norm)

    def construct(self, x):
        x = self.layer_root(x)
        x = self.layer1_1(x)
        c2 = self.layer1_2(x)
        x = self.layer1_3(c2)
        x = self.layer2_1(x)
        x = self.layer2_2(x)
        x = self.layer2_3(x)
        x = self.layer2_4(x)
        x = self.layer3_1(x)
        x = self.layer3_2(x)
        x = self.layer3_3(x)
        x = self.layer3_4(x)
        x = self.layer3_5(x)
        x = self.layer3_6(x)
        x = self.layer4_1(x)
        x = self.layer4_2(x)
        c5 = self.layer4_3(x)
        return c2, c5
class RootBlockBeta(nn.Cell):
    """
    ResNet V1 beta root block definition.

    Args:
        fine_tune_batch_norm (bool): Use mini-batch statistics in batch norm layers. Default: False.

    Returns:
        Tensor, the block unit's output.

    Examples:
        >>> RootBlockBeta()
    """
    def __init__(self, fine_tune_batch_norm=False):
        super(RootBlockBeta, self).__init__()
        self.conv1 = _conv_bn_relu(3, 64, ksize=3, stride=2, padding=0, pad_mode="valid",
                                   use_batch_statistics=fine_tune_batch_norm)
        self.conv2 = _conv_bn_relu(64, 64, ksize=3, stride=1, padding=0, pad_mode="same",
                                   use_batch_statistics=fine_tune_batch_norm)
        self.conv3 = _conv_bn_relu(64, 128, ksize=3, stride=1, padding=0, pad_mode="same",
                                   use_batch_statistics=fine_tune_batch_norm)

    def construct(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x
def resnet50_dl(fine_tune_batch_norm=False):
    """Get a ResNet-50 (beta root) backbone for DeepLab."""
    return ResNetV1(fine_tune_batch_norm)
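

# A hedged usage sketch, not part of the original diff; assumes a configured
# MindSpore context. The function name and input shape are illustrative.
def _example_backbone_forward():
    """Illustrative only: run a dummy 513x513 input through the backbone."""
    import numpy as np
    from mindspore import Tensor
    backbone = resnet50_dl(fine_tune_batch_norm=False)
    # c2 is the low-level feature map, c5 the final 2048-channel feature map.
    c2, c5 = backbone(Tensor(np.zeros((1, 3, 513, 513), np.float32)))
    return c2, c5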
@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config setting, used in train.py and evaluation.py.
"""
from easydict import EasyDict as ed

config = ed({
    "learning_rate": 0.0014,
    "weight_decay": 0.00005,
    "momentum": 0.97,
    "crop_size": 513,
    "eval_scales": [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
    "atrous_rates": None,
    "image_pyramid": None,
    "output_stride": 16,
    "fine_tune_batch_norm": False,
    "ignore_label": 255,
    "decoder_output_stride": None,
    "seg_num_classes": 21,
    "epoch_size": 6,
    "batch_size": 2,
    "enable_save_ckpt": True,
    "save_checkpoint_steps": 10000,
    "save_checkpoint_num": 1
})
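

# A hedged example of consuming this config, not part of the original diff;
# the import path and function name are assumptions about the repository layout.
def _example_build_net():
    """Illustrative only: mirrors how a train script would pass config into the model."""
    from src.deeplabv3 import deeplabv3_resnet50  # hypothetical path
    return deeplabv3_resnet50(config.seg_num_classes,
                              [config.batch_size, 3, config.crop_size, config.crop_size],
                              infer_scale_sizes=config.eval_scales,
                              atrous_rates=config.atrous_rates,
                              decoder_output_stride=config.decoder_output_stride,
                              output_stride=config.output_stride,
                              fine_tune_batch_norm=config.fine_tune_batch_norm,
                              image_pyramid=config.image_pyramid)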
@@ -0,0 +1,457 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepLabv3."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
    DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
class ASPPSampleBlock(nn.Cell):
    """ASPP sample block."""
    def __init__(self, feature_shape, scale_size, output_stride):
        super(ASPPSampleBlock, self).__init__()
        sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1
        sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)
class ASPP(nn.Cell):
    """
    ASPP model for DeepLabv3.

    Args:
        channel (int): Input channel.
        depth (int): Output channel.
        feature_shape (list): Shape of the input feature, [h, w].
        scale_sizes (list): Input scales for multi-scale feature extraction.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine-tune the batch norm parameters or not. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ASPP(2048, 256, [14, 14], [1.0], [6], 16)
    """
    def __init__(self, channel, depth, feature_shape, scale_sizes,
                 atrous_rates, output_stride, fine_tune_batch_norm=False):
        super(ASPP, self).__init__()
        self.aspp0 = _conv_bn_relu(channel,
                                   depth,
                                   ksize=1,
                                   stride=1,
                                   use_batch_statistics=fine_tune_batch_norm)
        self.atrous_rates = []
        if atrous_rates is not None:
            self.atrous_rates = atrous_rates
        self.aspp_pointwise = _conv_bn_relu(channel,
                                            depth,
                                            ksize=1,
                                            stride=1,
                                            use_batch_statistics=fine_tune_batch_norm)
        self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel,
                                                              channel_multiplier=1,
                                                              kernel_size=3,
                                                              stride=1,
                                                              dilation=1,
                                                              pad_mode="valid")
        self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm)
        self.aspp_depth_relu = nn.ReLU()
        self.aspp_depths = []
        self.aspp_depth_spacetobatchs = []
        self.aspp_depth_batchtospaces = []
        for scale_size in scale_sizes:
            aspp_scale_depth_size = np.ceil((feature_shape[0] * scale_size) / 16)
            if atrous_rates is None:
                break
            for rate in atrous_rates:
                # Find the smallest multiple of the rate that covers the padded
                # feature size, so SpaceToBatch sees an evenly divisible input.
                padding = 0
                for j in range(100):
                    padded_size = rate * j
                    if padded_size >= aspp_scale_depth_size + 2 * rate:
                        padding = padded_size - aspp_scale_depth_size - 2 * rate
                        break
                paddings = [[rate, rate + int(padding)],
                            [rate, rate + int(padding)]]
                self.aspp_depth_spacetobatch = SpaceToBatch(rate, paddings)
                self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch)
                crops = [[0, int(padding)], [0, int(padding)]]
                self.aspp_depth_batchtospace = BatchToSpace(rate, crops)
                self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace)
        self.aspp_depths = nn.CellList(self.aspp_depths)
        self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs)
        self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces)

        self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1])))
        self.global_poolings = []
        for scale_size in scale_sizes:
            pooling_h = np.ceil((feature_shape[0] * scale_size) / output_stride)
            pooling_w = np.ceil((feature_shape[1] * scale_size) / output_stride)
            self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w))))
        self.global_poolings = nn.CellList(self.global_poolings)
        self.conv_bn = _conv_bn_relu(channel,
                                     depth,
                                     ksize=1,
                                     stride=1,
                                     use_batch_statistics=fine_tune_batch_norm)
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride))
        self.samples = nn.CellList(self.samples)
        self.feature_shape = feature_shape
        self.concat = P.Concat(axis=1)
    def construct(self, x, scale_index=0):
        aspp0 = self.aspp0(x)
        aspp1 = self.global_poolings[scale_index](x)
        aspp1 = self.conv_bn(aspp1)
        aspp1 = self.samples[scale_index](aspp1)
        output = self.concat((aspp1, aspp0))
        for i in range(len(self.atrous_rates)):
            aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x)
            aspp_i = self.aspp_depth_depthwiseconv(aspp_i)
            aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i)
            aspp_i = self.aspp_depth_bn(aspp_i)
            aspp_i = self.aspp_depth_relu(aspp_i)
            aspp_i = self.aspp_pointwise(aspp_i)
            output = self.concat((output, aspp_i))
        return output
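

# A hedged usage sketch, not part of the original diff: ASPP on a 14x14 feature
# map at output_stride 16. Output channels are depth * (2 + len(atrous_rates)).
# The function name is hypothetical.
def _example_aspp_forward():
    """Illustrative only."""
    from mindspore import Tensor
    aspp = ASPP(channel=2048, depth=256, feature_shape=[14, 14],
                scale_sizes=[1.0], atrous_rates=[6], output_stride=16)
    return aspp(Tensor(np.zeros((1, 2048, 14, 14), np.float32)))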
class DecoderSampleBlock(nn.Cell):
    """Decoder sample block."""
    def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4):
        super(DecoderSampleBlock, self).__init__()
        sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1
        sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)
class Decoder(nn.Cell):
    """
    Decode module for DeepLabv3.

    Args:
        low_level_channel (int): Low-level input channel.
        channel (int): Input channel.
        depth (int): Output channel.
        feature_shape (list): Shape of the input feature, [h, w].
        scale_sizes (list): Input scales for multi-scale feature extraction.
        decoder_output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine-tune the batch norm parameters or not.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> Decoder(256, 256, 256, [56, 56], [1.0], 4, False)
    """
    def __init__(self,
                 low_level_channel,
                 channel,
                 depth,
                 feature_shape,
                 scale_sizes,
                 decoder_output_stride,
                 fine_tune_batch_norm):
        super(Decoder, self).__init__()
        self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1,
                                                pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
        self.decoder_depth0 = _deep_conv_bn_relu(channel + 48,
                                                 channel_multiplier=1,
                                                 ksize=3,
                                                 stride=1,
                                                 pad_mode="same",
                                                 dilation=1,
                                                 use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise0 = _conv_bn_relu(channel + 48,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
        self.decoder_depth1 = _deep_conv_bn_relu(depth,
                                                 channel_multiplier=1,
                                                 ksize=3,
                                                 stride=1,
                                                 pad_mode="same",
                                                 dilation=1,
                                                 use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise1 = _conv_bn_relu(depth,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
        self.depth = depth
        self.concat = P.Concat(axis=1)
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride))
        self.samples = nn.CellList(self.samples)

    def construct(self, x, low_level_feature, scale_index):
        low_level_feature = self.feature_projection(low_level_feature)
        low_level_feature = self.samples[scale_index](low_level_feature)
        x = self.samples[scale_index](x)
        output = self.concat((x, low_level_feature))
        output = self.decoder_depth0(output)
        output = self.decoder_pointwise0(output)
        output = self.decoder_depth1(output)
        output = self.decoder_pointwise1(output)
        return output
class SingleDeepLabV3(nn.Cell):
    """
    DeepLabv3 Network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N, C, H, W].
        backbone (Cell): Backbone network.
        channel (int): ResNet output channel.
        depth (int): ASPP block depth.
        scale_sizes (list): Input scales for multi-scale feature extraction.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        decoder_output_stride (int): The ratio of input to output spatial resolution in the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine-tune the batch norm parameters or not. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> SingleDeepLabV3(num_classes=10,
        ...                 feature_shape=[1, 3, 224, 224],
        ...                 backbone=resnet50_dl(),
        ...                 channel=2048,
        ...                 depth=256,
        ...                 scale_sizes=[1.0],
        ...                 atrous_rates=[6],
        ...                 decoder_output_stride=4,
        ...                 output_stride=16)
    """
    def __init__(self,
                 num_classes,
                 feature_shape,
                 backbone,
                 channel,
                 depth,
                 scale_sizes,
                 atrous_rates,
                 decoder_output_stride,
                 output_stride,
                 fine_tune_batch_norm=False):
        super(SingleDeepLabV3, self).__init__()
        self.num_classes = num_classes
        self.channel = channel
        self.depth = depth
        self.scale_sizes = []
        for scale_size in np.sort(scale_sizes):
            self.scale_sizes.append(scale_size)
        self.net = backbone
        self.aspp = ASPP(channel=self.channel,
                         depth=self.depth,
                         feature_shape=[feature_shape[2],
                                        feature_shape[3]],
                         scale_sizes=self.scale_sizes,
                         atrous_rates=atrous_rates,
                         output_stride=output_stride,
                         fine_tune_batch_norm=fine_tune_batch_norm)
        self.aspp.add_flags(loop_can_unroll=True)
        atrous_rates_len = 0
        if atrous_rates is not None:
            atrous_rates_len = len(atrous_rates)
        self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth,
                                 ksize=1,
                                 stride=1,
                                 use_batch_statistics=fine_tune_batch_norm)
        self.fc2 = nn.Conv2d(depth,
                             num_classes,
                             kernel_size=1,
                             stride=1,
                             has_bias=True)
        self.upsample = P.ResizeBilinear((int(feature_shape[2]),
                                          int(feature_shape[3])),
                                         align_corners=True)
        self.samples = []
        for scale_size in self.scale_sizes:
            self.samples.append(SampleBlock(feature_shape, scale_size))
        self.samples = nn.CellList(self.samples)
        self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]),
                              float(feature_shape[3])]

        self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
        self.dropout = nn.Dropout(keep_prob=0.9)
        self.shape = P.Shape()
        self.decoder_output_stride = decoder_output_stride
        if decoder_output_stride is not None:
            self.decoder = Decoder(low_level_channel=depth,
                                   channel=depth,
                                   depth=depth,
                                   feature_shape=[feature_shape[2],
                                                  feature_shape[3]],
                                   scale_sizes=self.scale_sizes,
                                   decoder_output_stride=decoder_output_stride,
                                   fine_tune_batch_norm=fine_tune_batch_norm)
    def construct(self, x, scale_index=0):
        # Normalize pixel values from [0, 255] to [-1, 1], then pad by one pixel.
        x = (2.0 / 255.0) * x - 1.0
        x = self.pad(x)
        low_level_feature, feature_map = self.net(x)
        # Pick the first registered scale whose resized height covers the
        # (padded) input, then run ASPP and the optional decoder at that scale.
        for scale_size in self.scale_sizes:
            if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2:
                output = self.aspp(feature_map, scale_index)
                output = self.fc1(output)
                if self.decoder_output_stride is not None:
                    output = self.decoder(output, low_level_feature, scale_index)
                output = self.fc2(output)
                output = self.samples[scale_index](output)
                return output
            scale_index += 1
        return feature_map
class SampleBlock(nn.Cell):
    """Sample block."""
    def __init__(self,
                 feature_shape,
                 scale_size=1.0):
        super(SampleBlock, self).__init__()
        sample_h = np.ceil(float(feature_shape[2]) * scale_size)
        sample_w = np.ceil(float(feature_shape[3]) * scale_size)
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)
class DeepLabV3(nn.Cell):
    """DeepLabV3 model."""
    def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates,
                 decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid):
        super(DeepLabV3, self).__init__()
        self.infer_scale_sizes = []
        if infer_scale_sizes is not None:
            self.infer_scale_sizes = infer_scale_sizes
        if image_pyramid is None:
            image_pyramid = [1.0]
        self.image_pyramid = image_pyramid
        # Training scales come from the image pyramid; inference scales follow.
        scale_sizes = []
        for pyramid in image_pyramid:
            scale_sizes.append(pyramid)
        for scale in self.infer_scale_sizes:
            scale_sizes.append(scale)
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(SampleBlock(feature_shape, scale_size))
        self.samples = nn.CellList(self.samples)
        self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes,
                                         feature_shape=feature_shape,
                                         backbone=backbone,
                                         channel=channel,
                                         depth=depth,
                                         scale_sizes=scale_sizes,
                                         atrous_rates=atrous_rates,
                                         decoder_output_stride=decoder_output_stride,
                                         output_stride=output_stride,
                                         fine_tune_batch_norm=fine_tune_batch_norm)
        self.softmax = P.Softmax(axis=1)
        self.concat = P.Concat(axis=2)
        self.expand_dims = P.ExpandDims()
        self.reduce_mean = P.ReduceMean()
        self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
                                               int(feature_shape[3])),
                                              align_corners=True)
    def construct(self, x):
        logits = ()
        if self.training:
            # Training: average the raw logits over the image pyramid scales.
            if len(self.image_pyramid) >= 1:
                if self.image_pyramid[0] == 1:
                    logits = self.deeplabv3(x)
                else:
                    x1 = self.samples[0](x)
                    logits = self.deeplabv3(x1)
                logits = self.sample_common(logits)
                logits = self.expand_dims(logits, 2)
                for i in range(len(self.image_pyramid) - 1):
                    x_i = self.samples[i + 1](x)
                    logits_i = self.deeplabv3(x_i)
                    logits_i = self.sample_common(logits_i)
                    logits_i = self.expand_dims(logits_i, 2)
                    logits = self.concat((logits, logits_i))
                logits = self.reduce_mean(logits, 2)
            return logits
        # Inference: average the softmax probabilities over the infer scales.
        if len(self.infer_scale_sizes) >= 1:
            infer_index = len(self.image_pyramid)
            x1 = self.samples[infer_index](x)
            logits = self.deeplabv3(x1)
            logits = self.sample_common(logits)
            logits = self.softmax(logits)
            logits = self.expand_dims(logits, 2)
            for i in range(len(self.infer_scale_sizes) - 1):
                x_i = self.samples[i + 1 + infer_index](x)
                logits_i = self.deeplabv3(x_i)
                logits_i = self.sample_common(logits_i)
                logits_i = self.softmax(logits_i)
                logits_i = self.expand_dims(logits_i, 2)
                logits = self.concat((logits, logits_i))
            logits = self.reduce_mean(logits, 2)
        return logits
def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid,
                       infer_scale_sizes, atrous_rates=None, decoder_output_stride=None,
                       output_stride=16, fine_tune_batch_norm=False):
    """
    ResNet50 based DeepLabv3 network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N, C, H, W].
        image_pyramid (list): Input scales for multi-scale feature extraction.
        infer_scale_sizes (list): The scales to resize images for inference.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling. Default: None.
        decoder_output_stride (int): The ratio of input to output spatial resolution in the decoder. Default: None.
        output_stride (int): The ratio of input to output spatial resolution. Default: 16.
        fine_tune_batch_norm (bool): Fine-tune the batch norm parameters or not. Default: False.

    Returns:
        Cell, cell instance of ResNet50 based DeepLabv3 neural network.

    Examples:
        >>> deeplabv3_resnet50(100, [1, 3, 224, 224], [1.0], [1.0])
    """
    return DeepLabV3(num_classes=num_classes,
                     feature_shape=feature_shape,
                     backbone=resnet50_dl(fine_tune_batch_norm),
                     channel=2048,
                     depth=256,
                     infer_scale_sizes=infer_scale_sizes,
                     atrous_rates=atrous_rates,
                     decoder_output_stride=decoder_output_stride,
                     output_stride=output_stride,
                     fine_tune_batch_norm=fine_tune_batch_norm,
                     image_pyramid=image_pyramid)
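

# A hedged end-to-end sketch, not part of the original diff; assumes a
# configured MindSpore context. The function name is hypothetical.
def _example_inference():
    """Illustrative only: eval mode averages softmax over infer_scale_sizes."""
    from mindspore import Tensor
    net = deeplabv3_resnet50(21, [1, 3, 513, 513], image_pyramid=[1.0],
                             infer_scale_sizes=[1.0], atrous_rates=None,
                             decoder_output_stride=None, output_stride=16)
    net.set_train(False)
    # Returns per-pixel class probabilities over the 21 classes.
    return net(Tensor(np.zeros((1, 3, 513, 513), np.float32)))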
@@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import abc
import os
import time

from .utils.adapter import get_raw_samples, read_image
class BaseDataset:
    """
    Create dataset.

    Args:
        data_url (str): The path of data.
        usage (str): Whether the data is used to train or eval (default='train').

    Returns:
        Dataset.
    """
    def __init__(self, data_url, usage):
        self.data_url = data_url
        self.usage = usage
        self.cur_index = 0
        self.samples = []
        _s_time = time.time()
        self._load_samples()
        _e_time = time.time()
        print(f"load samples success, time cost = {_e_time - _s_time}")

    def __getitem__(self, item):
        sample = self.samples[item]
        return self._next_data(sample)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _next_data(sample):
        image_path = sample[0]
        mask_image_path = sample[1]
        image = read_image(image_path)
        mask_image = read_image(mask_image_path)
        return [image, mask_image]

    @abc.abstractmethod
    def _load_samples(self):
        pass
class HwVocRawDataset(BaseDataset):
    """
    Create dataset with raw data.

    Args:
        data_url (str): The path of data.
        usage (str): Whether the data is used to train or eval (default='train').

    Returns:
        Dataset.
    """
    def __init__(self, data_url, usage="train"):
        super().__init__(data_url, usage)

    def _load_samples(self):
        try:
            self.samples = get_raw_samples(os.path.join(self.data_url, self.usage))
        except Exception as e:
            print("load HwVocRawDataset failed!")
            raise e
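

# A hedged usage sketch, not part of the original diff; the path is hypothetical
# and must contain "train/Images" and "train/SegmentationClassRaw".
def _example_dataset():
    """Illustrative only: fetch the first image/mask pair."""
    dataset = HwVocRawDataset("/path/to/voc", usage="train")
    image, mask = dataset[0]
    return len(dataset), image.shape, mask.shape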
@@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class OhemLoss(nn.Cell):
    """OHEM loss cell. As written, it computes cross entropy averaged over the non-ignored pixels."""
    def __init__(self, num, ignore_label):
        super(OhemLoss, self).__init__()
        self.mul = P.Mul()
        self.shape = P.Shape()
        self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
        self.squeeze = P.Squeeze()
        self.num = num
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.not_equal = P.NotEqual()
        self.equal = P.Equal()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.fill = P.Fill()
        self.transpose = P.Transpose()
        self.ignore_label = ignore_label
        self.loss_weight = 1.0
    def construct(self, logits, labels):
        # (N, C, H, W) -> (N*H*W, C): treat every pixel as one classification sample.
        logits = self.transpose(logits, (0, 2, 3, 1))
        logits = self.reshape(logits, (-1, self.num))
        labels = F.cast(labels, mstype.int32)
        labels = self.reshape(labels, (-1,))
        one_hot_labels = self.one_hot(labels)
        losses = self.cross_entropy(logits, one_hot_labels)[0]
        # Zero out the loss for pixels whose label equals ignore_label.
        weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
        weighted_losses = self.mul(losses, weights)
        loss = self.reduce_sum(weighted_losses, (0,))
        zeros = self.fill(mstype.float32, self.shape(weights), 0.0)
        ones = self.fill(mstype.float32, self.shape(weights), 1.0)
        present = self.select(self.equal(weights, zeros), zeros, ones)
        present = self.reduce_sum(present, (0,))
        # Guard against division by zero when every pixel is ignored.
        zeros = self.fill(mstype.float32, self.shape(present), 0.0)
        min_control = self.fill(mstype.float32, self.shape(present), 1.0)
        present = self.select(self.equal(present, zeros), min_control, present)
        loss = loss / present
        return loss
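

# A hedged usage sketch, not part of the original diff; assumes a configured
# MindSpore context. The function name and shapes are illustrative.
def _example_ohem_loss():
    """Illustrative only: pixels labeled 255 contribute no loss."""
    import numpy as np
    from mindspore import Tensor
    loss_fn = OhemLoss(num=21, ignore_label=255)
    logits = Tensor(np.random.randn(1, 21, 4, 4).astype(np.float32))
    labels = Tensor(np.zeros((1, 4, 4), np.float32))
    return loss_fn(logits, labels)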
@@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
import numpy as np

from .ei_dataset import HwVocRawDataset
from .utils import custom_transforms as tr
class DataTransform:
    """Transform dataset for DeepLabV3."""

    def __init__(self, args, usage):
        self.args = args
        self.usage = usage

    def __call__(self, image, label):
        if self.usage == "train":
            return self._train(image, label)
        if self.usage == "eval":
            return self._eval(image, label)
        return None

    def _train(self, image, label):
        """
        Process training data.

        Args:
            image (numpy.ndarray): Image data.
            label (numpy.ndarray): Dataset label.
        """
        image = Image.fromarray(image)
        label = Image.fromarray(label)
        rsc_tr = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size)
        image, label = rsc_tr(image, label)
        rhf_tr = tr.RandomHorizontalFlip()
        image, label = rhf_tr(image, label)
        image = np.array(image).astype(np.float32)
        label = np.array(label).astype(np.float32)
        return image, label

    def _eval(self, image, label):
        """
        Process eval data.

        Args:
            image (numpy.ndarray): Image data.
            label (numpy.ndarray): Dataset label.
        """
        image = Image.fromarray(image)
        label = Image.fromarray(label)
        fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size)
        image, label = fsc_tr(image, label)
        image = np.array(image).astype(np.float32)
        label = np.array(label).astype(np.float32)
        return image, label
def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train", shuffle=True):
    """
    Create Dataset for DeepLabV3.

    Args:
        args (dict): Train parameters.
        data_url (str): Dataset path.
        epoch_num (int): Epoch of dataset (default=1).
        batch_size (int): Batch size of dataset (default=1).
        usage (str): Whether the dataset is used to train or eval (default='train').
        shuffle (bool): Shuffle the training samples (default=True).

    Returns:
        Dataset.
    """
    # create iter dataset
    dataset = HwVocRawDataset(data_url, usage=usage)
    dataset_len = len(dataset)

    # wrapped with GeneratorDataset
    dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
    dataset.set_dataset_size(dataset_len)
    dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))

    channelswap_op = C.HWC2CHW()
    dataset = dataset.map(input_columns="image", operations=channelswap_op)

    # 1464 samples / batch_size 8 = 183 batches
    # epoch_num is num of steps
    # 3658 steps / 183 = 20 epochs
    if usage == "train" and shuffle:
        dataset = dataset.shuffle(1464)
    dataset = dataset.batch(batch_size, drop_remainder=(usage == "train"))
    dataset = dataset.repeat(count=epoch_num)
    dataset.map_model = 4

    return dataset
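

# A hedged usage sketch, not part of the original diff; the import path and data
# path are hypothetical. Note RandomScaleCrop also reads args.base_size, which
# src/config.py does not define, so it is set here.
def _example_create_dataset():
    """Illustrative only."""
    from ..config import config  # hypothetical import path
    config.base_size = config.crop_size
    ds = create_dataset(config, "/path/to/voc", epoch_num=1,
                        batch_size=config.batch_size, usage="train")
    return ds.get_dataset_size()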
@@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mIoU."""
import numpy as np
from mindspore.nn.metrics.metric import Metric
def confuse_matrix(target, pred, n):
    """Build an n x n confusion matrix over the valid (0 <= label < n) pixels."""
    k = (target >= 0) & (target < n)
    return np.bincount(n * target[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)


def iou(hist):
    """Mean IoU over the classes that appear in the confusion matrix."""
    denominator = hist.sum(1) + hist.sum(0) - np.diag(hist)
    res = np.diag(hist) / np.where(denominator > 0, denominator, 1)
    res = np.sum(res) / np.count_nonzero(denominator)
    return res
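

# A quick numeric check of the two helpers above, not part of the original diff;
# the function name is hypothetical.
def _example_iou():
    """For target [0, 0, 1, 1] and pred [0, 1, 1, 1], the per-class IoUs are
    1/2 and 2/3, so the mean is (1/2 + 2/3) / 2 = 0.5833..."""
    hist = confuse_matrix(np.array([0, 0, 1, 1]), np.array([0, 1, 1, 1]), 2)
    # hist == [[1, 1], [0, 2]]
    return iou(hist)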
class MiouPrecision(Metric):
    """Calculate mIoU precision."""
    def __init__(self, num_class=21):
        super(MiouPrecision, self).__init__()
        if not isinstance(num_class, int):
            raise TypeError('num_class should be integer type, but got {}'.format(type(num_class)))
        if num_class < 1:
            raise ValueError('num_class must be at least 1, but got {}'.format(num_class))
        self._num_class = num_class
        self._mIoU = []
        self.clear()

    def clear(self):
        self._hist = np.zeros((self._num_class, self._num_class))
        self._mIoU = []

    def update(self, *inputs):
        if len(inputs) != 2:
            raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        predict_in = self._convert_data(inputs[0])
        label_in = self._convert_data(inputs[1])
        if predict_in.shape[1] != self._num_class:
            raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
                             'classes'.format(self._num_class, predict_in.shape[1]))
        pred = np.argmax(predict_in, axis=1)
        label = label_in
        if len(label.flatten()) != len(pred.flatten()):
            print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
            raise ValueError('Pixel count not match: len(gt) = {:d}, len(pred) = {:d}'
                             .format(len(label.flatten()), len(pred.flatten())))
        self._hist = confuse_matrix(label.flatten(), pred.flatten(), self._num_class)
        mIoUs = iou(self._hist)
        self._mIoU.append(mIoUs)

    def eval(self):
        """
        Computes the mIoU categorical accuracy.
        """
        mIoU = np.nanmean(self._mIoU)
        print('mIoU = {}'.format(mIoU))
        return mIoU
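

# A hedged usage sketch, not part of the original diff; the function name is
# hypothetical.
def _example_miou():
    """Illustrative only: perfect predictions on a single class give mIoU 1.0."""
    metric = MiouPrecision(num_class=21)
    metric.clear()
    pred = np.zeros((1, 21, 4, 4))
    pred[:, 0] = 1.0                      # argmax picks class 0 everywhere
    label = np.zeros((1, 4, 4), np.int32)
    metric.update(pred, label)
    return metric.eval()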
@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| @@ -0,0 +1,67 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Adapter dataset.""" | |||||
| import fnmatch | |||||
| import io | |||||
| import os | |||||
| import numpy as np | |||||
| from PIL import Image | |||||
| from ..utils import file_io | |||||
| def get_raw_samples(data_url): | |||||
|     """ | |||||
|     Get paired image/segmentation samples from raw data. | |||||
|     Args: | |||||
|         data_url (str): Dataset path. | |||||
|     Returns: | |||||
|         list, pairs of [image path, segmentation path]. | |||||
|     """ | |||||
| def _list_files(dir_path, pattern): | |||||
| full_files = [] | |||||
| _, _, files = next(file_io.walk(dir_path)) | |||||
| for f in files: | |||||
| if fnmatch.fnmatch(f.lower(), pattern.lower()): | |||||
| full_files.append(os.path.join(dir_path, f)) | |||||
| return full_files | |||||
| img_files = _list_files(os.path.join(data_url, "Images"), "*.jpg") | |||||
| seg_files = _list_files(os.path.join(data_url, "SegmentationClassRaw"), "*.png") | |||||
| files = [] | |||||
| for img_file in img_files: | |||||
| _, file_name = os.path.split(img_file) | |||||
| name, _ = os.path.splitext(file_name) | |||||
| seg_file = os.path.join(data_url, "SegmentationClassRaw", ".".join([name, "png"])) | |||||
| if seg_file in seg_files: | |||||
| files.append([img_file, seg_file]) | |||||
| return files | |||||
| def read_image(img_path): | |||||
|     """ | |||||
|     Read an image from file into a numpy array. | |||||
|     Args: | |||||
|         img_path (str): Image path. | |||||
|     Returns: | |||||
|         numpy.ndarray, the decoded image. | |||||
|     """ | |||||
| img = file_io.read(img_path.strip(), binary=True) | |||||
| data = io.BytesIO(img) | |||||
| img = Image.open(data) | |||||
| return np.array(img) | |||||
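| # Minimal usage sketch (illustrative; the dataset path below is hypothetical): | |||||
| if __name__ == "__main__": | |||||
|     samples = get_raw_samples("/path/to/voc2012/train") | |||||
|     if samples: | |||||
|         img = read_image(samples[0][0])  # image as an HWC uint8 array | |||||
|         seg = read_image(samples[0][1])  # segmentation mask as an HW array | |||||
|         print(img.shape, seg.shape) | |||||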
| @@ -0,0 +1,149 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Random process dataset.""" | |||||
| import random | |||||
| import numpy as np | |||||
| from PIL import Image, ImageOps, ImageFilter | |||||
| class Normalize: | |||||
| """Normalize a tensor image with mean and standard deviation. | |||||
| Args: | |||||
| mean (tuple): means for each channel. | |||||
| std (tuple): standard deviations for each channel. | |||||
| """ | |||||
| def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): | |||||
| self.mean = mean | |||||
| self.std = std | |||||
| def __call__(self, img, mask): | |||||
| img = np.array(img).astype(np.float32) | |||||
| mask = np.array(mask).astype(np.float32) | |||||
| img = ((img - self.mean) / self.std).astype(np.float32) | |||||
| return img, mask | |||||
| class RandomHorizontalFlip: | |||||
| """Randomly decide whether to horizontal flip.""" | |||||
| def __call__(self, img, mask): | |||||
| if random.random() < 0.5: | |||||
| img = img.transpose(Image.FLIP_LEFT_RIGHT) | |||||
| mask = mask.transpose(Image.FLIP_LEFT_RIGHT) | |||||
| return img, mask | |||||
| class RandomRotate: | |||||
| """ | |||||
| Randomly decide whether to rotate. | |||||
| Args: | |||||
| degree (float): The degree of rotate. | |||||
| """ | |||||
| def __init__(self, degree): | |||||
| self.degree = degree | |||||
| def __call__(self, img, mask): | |||||
| rotate_degree = random.uniform(-1 * self.degree, self.degree) | |||||
| img = img.rotate(rotate_degree, Image.BILINEAR) | |||||
| mask = mask.rotate(rotate_degree, Image.NEAREST) | |||||
| return img, mask | |||||
| class RandomGaussianBlur: | |||||
| """Randomly decide whether to filter image with gaussian blur.""" | |||||
| def __call__(self, img, mask): | |||||
| if random.random() < 0.5: | |||||
| img = img.filter(ImageFilter.GaussianBlur( | |||||
| radius=random.random())) | |||||
| return img, mask | |||||
| class RandomScaleCrop: | |||||
| """Randomly decide whether to scale and crop image.""" | |||||
| def __init__(self, base_size, crop_size, fill=0): | |||||
| self.base_size = base_size | |||||
| self.crop_size = crop_size | |||||
| self.fill = fill | |||||
| def __call__(self, img, mask): | |||||
| # random scale (short edge) | |||||
| short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) | |||||
| w, h = img.size | |||||
| if h > w: | |||||
| ow = short_size | |||||
| oh = int(1.0 * h * ow / w) | |||||
| else: | |||||
| oh = short_size | |||||
| ow = int(1.0 * w * oh / h) | |||||
| img = img.resize((ow, oh), Image.BILINEAR) | |||||
| mask = mask.resize((ow, oh), Image.NEAREST) | |||||
| # pad crop | |||||
| if short_size < self.crop_size: | |||||
| padh = self.crop_size - oh if oh < self.crop_size else 0 | |||||
| padw = self.crop_size - ow if ow < self.crop_size else 0 | |||||
| img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) | |||||
| mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) | |||||
| # random crop crop_size | |||||
| w, h = img.size | |||||
| x1 = random.randint(0, w - self.crop_size) | |||||
| y1 = random.randint(0, h - self.crop_size) | |||||
| img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) | |||||
| mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) | |||||
| return img, mask | |||||
| class FixScaleCrop: | |||||
| """Scale and crop image with fixing size.""" | |||||
| def __init__(self, crop_size): | |||||
| self.crop_size = crop_size | |||||
| def __call__(self, img, mask): | |||||
| w, h = img.size | |||||
| if w > h: | |||||
| oh = self.crop_size | |||||
| ow = int(1.0 * w * oh / h) | |||||
| else: | |||||
| ow = self.crop_size | |||||
| oh = int(1.0 * h * ow / w) | |||||
| img = img.resize((ow, oh), Image.BILINEAR) | |||||
| mask = mask.resize((ow, oh), Image.NEAREST) | |||||
| # center crop | |||||
| w, h = img.size | |||||
| x1 = int(round((w - self.crop_size) / 2.)) | |||||
| y1 = int(round((h - self.crop_size) / 2.)) | |||||
| img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) | |||||
| mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) | |||||
| return img, mask | |||||
| class FixedResize: | |||||
| """Resize image with fixing size.""" | |||||
| def __init__(self, size): | |||||
| self.size = (size, size) | |||||
| def __call__(self, img, mask): | |||||
| assert img.size == mask.size | |||||
| img = img.resize(self.size, Image.BILINEAR) | |||||
| mask = mask.resize(self.size, Image.NEAREST) | |||||
| return img, mask | |||||
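| # Minimal usage sketch (illustrative): chain the transforms the way a dataset | |||||
| # pipeline would, on a synthetic PIL image/mask pair; the normalization stats | |||||
| # below are placeholders. | |||||
| if __name__ == "__main__": | |||||
|     img = Image.new("RGB", (500, 375)) | |||||
|     mask = Image.new("L", (500, 375)) | |||||
|     for op in (RandomHorizontalFlip(), RandomRotate(15), | |||||
|                RandomScaleCrop(base_size=513, crop_size=513, fill=255), | |||||
|                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))): | |||||
|         img, mask = op(img, mask) | |||||
|     print(img.shape, mask.shape)  # both are now numpy arrays cropped to 513 x 513 | |||||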
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """File operation module.""" | |||||
| import os | |||||
| def _is_obs(url): | |||||
| return url.startswith("obs://") or url.startswith("s3://") | |||||
| def read(url, binary=False): | |||||
|     """Read a local file; OBS/S3 URLs are not supported yet and return None.""" | |||||
|     if _is_obs(url): | |||||
|         # TODO read cloud file. | |||||
|         return None | |||||
|     with open(url, "rb" if binary else "r") as f: | |||||
|         return f.read() | |||||
| def walk(url): | |||||
|     """Walk a local directory tree; OBS/S3 URLs are not supported yet and return None.""" | |||||
|     if _is_obs(url): | |||||
|         # TODO read cloud file. | |||||
|         return None | |||||
|     return os.walk(url) | |||||
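| # Minimal usage sketch (illustrative; the local path is hypothetical). OBS/S3 | |||||
| # URLs currently fall through to None, so only local paths are exercised here. | |||||
| if __name__ == "__main__": | |||||
|     for root, dirs, files in walk("/tmp"): | |||||
|         print(root, len(dirs), len(files)) | |||||
|         break | |||||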
| @@ -0,0 +1,102 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train.""" | |||||
| import argparse | |||||
| import time | |||||
| import pytest | |||||
| import numpy as np | |||||
| from mindspore import context, Tensor | |||||
| from mindspore.nn.optim.momentum import Momentum | |||||
| from mindspore import Model | |||||
| from mindspore.train.callback import Callback | |||||
| from src.md_dataset import create_dataset | |||||
| from src.losses import OhemLoss | |||||
| from src.deeplabv3 import deeplabv3_resnet50 | |||||
| from src.config import config | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| # Expected dataset layout under data_url: | |||||
| # --train | |||||
| # --eval | |||||
| #   --Images | |||||
| #     --2008_001135.jpg | |||||
| #     --2008_001404.jpg | |||||
| #   --SegmentationClassRaw | |||||
| #     --2008_001135.png | |||||
| #     --2008_001404.png | |||||
| data_url = "/home/workspace/mindspore_dataset/voc/voc2012" | |||||
| class LossCallBack(Callback): | |||||
| """ | |||||
| Monitor the loss in training. | |||||
| Note: | |||||
| if per_print_times is 0 do not print loss. | |||||
| Args: | |||||
| per_print_times (int): Print loss every times. Default: 1. | |||||
| """ | |||||
|     def __init__(self, data_size, per_print_times=1): | |||||
|         super(LossCallBack, self).__init__() | |||||
|         if not isinstance(per_print_times, int) or per_print_times < 0: | |||||
|             raise ValueError("per_print_times must be an int and >= 0") | |||||
|         self.data_size = data_size | |||||
|         self._per_print_times = per_print_times | |||||
|         self.epoch_time = time.time() | |||||
|         self.time = 1000 | |||||
|         self.loss = 0 | |||||
| def epoch_begin(self, run_context): | |||||
| self.epoch_time = time.time() | |||||
|     def step_end(self, run_context): | |||||
|         cb_params = run_context.original_args() | |||||
|         # Average per-step time in milliseconds, measured from the start of the epoch. | |||||
|         epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||||
|         self.time = epoch_mseconds / self.data_size | |||||
|         self.loss = cb_params.net_outputs | |||||
|         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | |||||
|             print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, | |||||
|                                                                str(cb_params.net_outputs))) | |||||
| def model_fine_tune(train_net, fix_weight_layer): | |||||
|     """Reset all weights to a constant and freeze layers whose name contains fix_weight_layer.""" | |||||
|     for para in train_net.trainable_params(): | |||||
|         para.set_parameter_data(Tensor(np.ones(para.data.shape).astype(np.float32) * 0.02)) | |||||
|         if fix_weight_layer in para.name: | |||||
|             para.requires_grad = False | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_deeplabv3_1p(): | |||||
| start_time = time.time() | |||||
| epoch_size = 100 | |||||
|     args_opt = argparse.Namespace(base_size=config.crop_size, crop_size=config.crop_size, | |||||
|                                   batch_size=config.batch_size) | |||||
|     # Note: the "eval" split is used for this training test. | |||||
|     train_dataset = create_dataset(args_opt, data_url, epoch_size, config.batch_size, | |||||
|                                    usage="eval") | |||||
| dataset_size = train_dataset.get_dataset_size() | |||||
| callback = LossCallBack(dataset_size) | |||||
| net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], | |||||
| infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, | |||||
| decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, | |||||
| fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) | |||||
| net.set_train() | |||||
| model_fine_tune(net, 'layer') | |||||
| loss = OhemLoss(config.seg_num_classes, config.ignore_label) | |||||
|     opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name | |||||
|                           and 'depth' not in x.name and 'bias' not in x.name, | |||||
|                           net.trainable_params()), | |||||
|                    learning_rate=config.learning_rate, momentum=config.momentum, | |||||
|                    weight_decay=config.weight_decay) | |||||
| model = Model(net, loss, opt) | |||||
| model.train(epoch_size, train_dataset, callback) | |||||
| print(time.time() - start_time) | |||||
| print("expect loss: ", callback.loss) | |||||
| print("expect time: ", callback.time) | |||||
| expect_loss = 0.92 | |||||
| expect_time = 40 | |||||
| assert callback.loss.asnumpy() <= expect_loss | |||||
| assert callback.time <= expect_time | |||||