|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """DeepLabv3."""
-
- import numpy as np
- import mindspore.nn as nn
- from mindspore.ops import operations as P
- from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
- DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
-
-
class ASPPSampleBlock(nn.Cell):
    """
    Bilinear resize block used by the ASPP image-pooling branch.

    Args:
        feature_shape (list): Spatial shape; only indices 0 (h) and 1 (w) are read.
        scale_size (float): Scale factor applied to the spatial shape.
        output_stride (int): Divisor applied to the scaled shape.

    Returns:
        Tensor, the input resized with `P.ResizeBilinear(align_corners=True)`.
    """

    def __init__(self, feature_shape, scale_size, output_stride):
        super(ASPPSampleBlock, self).__init__()
        height, width = feature_shape[0], feature_shape[1]
        # Target size per axis: (dim * scale + 1) / stride + 1, truncated by int().
        target_size = (
            int((height * scale_size + 1) / output_stride + 1),
            int((width * scale_size + 1) / output_stride + 1),
        )
        self.sample = P.ResizeBilinear(target_size, align_corners=True)

    def construct(self, x):
        resized = self.sample(x)
        return resized
-
-
class ASPP(nn.Cell):
    """
    ASPP model for DeepLabv3.

    The output concatenates, along axis 1, an image-pooling branch, a 1x1
    convolution branch and one branch per atrous rate.

    Args:
        channel (int): Input channel.
        depth (int): Output channel of each branch.
        feature_shape (list): The shape of feature, [h, w].
        scale_sizes (list): Input scales for multi-scale feature extraction.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
            None disables the atrous branches.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.
            Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ASPP(2048, 256, [14, 14], [1], [6], 16)
    """

    def __init__(self, channel, depth, feature_shape, scale_sizes,
                 atrous_rates, output_stride, fine_tune_batch_norm=False):
        super(ASPP, self).__init__()
        # 1x1 convolution branch.
        self.aspp0 = _conv_bn_relu(channel,
                                   depth,
                                   ksize=1,
                                   stride=1,
                                   use_batch_statistics=fine_tune_batch_norm)
        self.atrous_rates = []
        if atrous_rates is not None:
            self.atrous_rates = atrous_rates
        # The cells below are shared by every atrous branch in construct().
        self.aspp_pointwise = _conv_bn_relu(channel,
                                            depth,
                                            ksize=1,
                                            stride=1,
                                            use_batch_statistics=fine_tune_batch_norm)
        self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel,
                                                              channel_multiplier=1,
                                                              kernel_size=3,
                                                              stride=1,
                                                              dilation=1,
                                                              pad_mode="valid")
        self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm)
        self.aspp_depth_relu = nn.ReLU()
        self.aspp_depths = []
        self.aspp_depth_spacetobatchs = []
        self.aspp_depth_batchtospaces = []

        # One SpaceToBatch/BatchToSpace pair per (scale, rate), stored scale-major
        # so construct() can index them with i + scale_index * len(atrous_rates).
        for scale_size in scale_sizes:
            # NOTE(review): the divisor is hard-coded to 16 rather than
            # output_stride, and only the height is used - confirm intent.
            aspp_scale_depth_size = np.ceil((feature_shape[0] * scale_size) / 16)
            # self.atrous_rates is empty when atrous_rates is None, so no pair is built.
            for rate in self.atrous_rates:
                # Extra padding: distance from (size + 2 * rate) up to the first
                # multiple of `rate` (searched over rate * j, j < 100) that is not
                # smaller; stays 0 if no such multiple is found in that range.
                padding = 0
                for j in range(100):
                    padded_size = rate * j
                    if padded_size >= aspp_scale_depth_size + 2 * rate:
                        padding = padded_size - aspp_scale_depth_size - 2 * rate
                        break
                paddings = [[rate, rate + int(padding)],
                            [rate, rate + int(padding)]]
                # The attribute assignment is kept from the original code; it ends
                # up holding the cell created in the last iteration.
                self.aspp_depth_spacetobatch = SpaceToBatch(rate, paddings)
                self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch)
                crops = [[0, int(padding)], [0, int(padding)]]
                self.aspp_depth_batchtospace = BatchToSpace(rate, crops)
                self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace)
        self.aspp_depths = nn.CellList(self.aspp_depths)
        self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs)
        self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces)

        self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1])))
        # One pooling cell per scale; the kernel covers the scaled feature map.
        self.global_poolings = []
        for scale_size in scale_sizes:
            pooling_h = np.ceil((feature_shape[0] * scale_size) / output_stride)
            # Width must come from feature_shape[1]; using the height here gave a
            # wrong kernel for non-square feature maps.
            pooling_w = np.ceil((feature_shape[1] * scale_size) / output_stride)
            self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w))))
        self.global_poolings = nn.CellList(self.global_poolings)
        self.conv_bn = _conv_bn_relu(channel,
                                     depth,
                                     ksize=1,
                                     stride=1,
                                     use_batch_statistics=fine_tune_batch_norm)
        # One resize block per scale, restoring the pooled map's spatial size.
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride))
        self.samples = nn.CellList(self.samples)
        self.feature_shape = feature_shape
        self.concat = P.Concat(axis=1)

    def construct(self, x, scale_index=0):
        # 1x1 convolution branch.
        aspp0 = self.aspp0(x)
        # Image-pooling branch: pool, project, then resize.
        aspp1 = self.global_poolings[scale_index](x)
        aspp1 = self.conv_bn(aspp1)
        aspp1 = self.samples[scale_index](aspp1)
        output = self.concat((aspp1, aspp0))

        # Atrous branches: the shared depthwise conv runs between the
        # SpaceToBatch/BatchToSpace pair selected for this scale and rate.
        for i in range(len(self.atrous_rates)):
            aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x)
            aspp_i = self.aspp_depth_depthwiseconv(aspp_i)
            aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i)
            aspp_i = self.aspp_depth_bn(aspp_i)
            aspp_i = self.aspp_depth_relu(aspp_i)
            aspp_i = self.aspp_pointwise(aspp_i)
            output = self.concat((output, aspp_i))
        return output
-
-
class DecoderSampleBlock(nn.Cell):
    """
    Bilinear resize block used by the decoder.

    Args:
        feature_shape (list): Spatial shape; only indices 0 (h) and 1 (w) are read.
        scale_size (float): Scale factor applied to the spatial shape. Default: 1.0.
        decoder_output_stride (int): Divisor applied to the scaled shape. Default: 4.

    Returns:
        Tensor, the input resized with `P.ResizeBilinear(align_corners=True)`.
    """

    def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4):
        super(DecoderSampleBlock, self).__init__()
        height, width = feature_shape[0], feature_shape[1]
        # Target size per axis: (dim * scale + 1) / stride + 1, truncated by int().
        target_size = (
            int((height * scale_size + 1) / decoder_output_stride + 1),
            int((width * scale_size + 1) / decoder_output_stride + 1),
        )
        self.sample = P.ResizeBilinear(target_size, align_corners=True)

    def construct(self, x):
        resized = self.sample(x)
        return resized
-
-
class Decoder(nn.Cell):
    """
    Decode module for DeepLabv3.

    Projects the low-level feature to 48 channels, resizes both inputs to the
    same size, concatenates them along axis 1 and refines the result with two
    depthwise + pointwise convolution stages.

    Args:
        low_level_channel (int): Low level input channel.
        channel (int): Input channel.
        depth (int): Output channel.
        feature_shape (list): Spatial shape; only indices 0 (h) and 1 (w) are read.
        scale_sizes (list): Input scales for multi-scale feature extraction.
        decoder_output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> Decoder(256, 256, 256, [56, 56], [1.0], 4, False)
    """

    def __init__(self,
                 low_level_channel,
                 channel,
                 depth,
                 feature_shape,
                 scale_sizes,
                 decoder_output_stride,
                 fine_tune_batch_norm):
        super(Decoder, self).__init__()
        # Channel count of the projected low-level feature.
        projected_channel = 48
        merged_channel = channel + projected_channel

        self.feature_projection = _conv_bn_relu(
            low_level_channel, projected_channel, ksize=1, stride=1,
            pad_mode="same", use_batch_statistics=fine_tune_batch_norm)

        # First refinement stage: depthwise 3x3 followed by pointwise 1x1.
        self.decoder_depth0 = _deep_conv_bn_relu(
            merged_channel,
            channel_multiplier=1,
            ksize=3,
            stride=1,
            pad_mode="same",
            dilation=1,
            use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise0 = _conv_bn_relu(
            merged_channel,
            depth,
            ksize=1,
            stride=1,
            use_batch_statistics=fine_tune_batch_norm)

        # Second refinement stage, operating on `depth` channels.
        self.decoder_depth1 = _deep_conv_bn_relu(
            depth,
            channel_multiplier=1,
            ksize=3,
            stride=1,
            pad_mode="same",
            dilation=1,
            use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise1 = _conv_bn_relu(
            depth,
            depth,
            ksize=1,
            stride=1,
            use_batch_statistics=fine_tune_batch_norm)

        self.depth = depth
        self.concat = P.Concat(axis=1)
        # One resize block per scale, selected by scale_index in construct().
        self.samples = nn.CellList(
            [DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride)
             for scale_size in scale_sizes])

    def construct(self, x, low_level_feature, scale_index):
        projected = self.feature_projection(low_level_feature)
        # Both inputs go through the same resize block so they can be concatenated.
        projected = self.samples[scale_index](projected)
        upsampled = self.samples[scale_index](x)
        merged = self.concat((upsampled, projected))
        merged = self.decoder_depth0(merged)
        merged = self.decoder_pointwise0(merged)
        merged = self.decoder_depth1(merged)
        merged = self.decoder_pointwise1(merged)
        return merged
-
-
class SingleDeepLabV3(nn.Cell):
    """
    DeepLabv3 Network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        backbone (Cell): Backbone Network; must return (low_level_feature, feature_map).
        channel (int): Resnet output channel.
        depth (int): ASPP block depth.
        scale_sizes (list): Input scales for multi-scale feature extraction.
            Stored in ascending order.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        decoder_output_stride (int): The ratio of input to output spatial resolution.
            None disables the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.
            Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> SingleDeepLabV3(num_classes=10,
        ...                 feature_shape=[1, 3, 224, 224],
        ...                 backbone=resnet50_dl(),
        ...                 channel=2048,
        ...                 depth=256,
        ...                 scale_sizes=[1.0],
        ...                 atrous_rates=[6],
        ...                 decoder_output_stride=4,
        ...                 output_stride=16)
    """

    def __init__(self,
                 num_classes,
                 feature_shape,
                 backbone,
                 channel,
                 depth,
                 scale_sizes,
                 atrous_rates,
                 decoder_output_stride,
                 output_stride,
                 fine_tune_batch_norm=False):
        super(SingleDeepLabV3, self).__init__()
        image_h, image_w = feature_shape[2], feature_shape[3]

        self.num_classes = num_classes
        self.channel = channel
        self.depth = depth
        # Ascending order; construct() picks the first scale that is large enough.
        self.scale_sizes = list(np.sort(scale_sizes))
        self.net = backbone
        self.aspp = ASPP(channel=self.channel,
                         depth=self.depth,
                         feature_shape=[image_h, image_w],
                         scale_sizes=self.scale_sizes,
                         atrous_rates=atrous_rates,
                         output_stride=output_stride,
                         fine_tune_batch_norm=fine_tune_batch_norm)
        self.aspp.add_flags(loop_can_unroll=True)

        # ASPP emits `depth` channels for each of its two fixed branches plus
        # one per atrous rate.
        atrous_rates_len = len(atrous_rates) if atrous_rates is not None else 0
        self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth,
                                 ksize=1,
                                 stride=1,
                                 use_batch_statistics=fine_tune_batch_norm)
        self.fc2 = nn.Conv2d(depth,
                             num_classes,
                             kernel_size=1,
                             stride=1,
                             has_bias=True)
        self.upsample = P.ResizeBilinear((int(image_h), int(image_w)),
                                         align_corners=True)
        # One output resize block per scale.
        self.samples = nn.CellList(
            [SampleBlock(feature_shape, scale_size) for scale_size in self.scale_sizes])
        self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]),
                              float(image_h), float(image_w)]

        # Pads one pixel on each side of both spatial axes.
        self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
        self.dropout = nn.Dropout(keep_prob=0.9)
        self.shape = P.Shape()
        self.decoder_output_stride = decoder_output_stride
        if decoder_output_stride is not None:
            self.decoder = Decoder(low_level_channel=depth,
                                   channel=depth,
                                   depth=depth,
                                   feature_shape=[image_h, image_w],
                                   scale_sizes=self.scale_sizes,
                                   decoder_output_stride=decoder_output_stride,
                                   fine_tune_batch_norm=fine_tune_batch_norm)

    def construct(self, x, scale_index=0):
        # Rescale the input: x * 2 / 255 - 1.
        x = (2.0 / 255.0) * x - 1.0
        x = self.pad(x)
        low_level_feature, feature_map = self.net(x)
        # Height of the input before the 1-pixel padding on each side.
        unpadded_height = self.shape(x)[2] - 2
        for scale_size in self.scale_sizes:
            # Use the first scale whose scaled height covers the actual input.
            if scale_size * self.feature_shape[2] + 1.0 >= unpadded_height:
                logits = self.aspp(feature_map, scale_index)
                logits = self.fc1(logits)
                if self.decoder_output_stride is not None:
                    logits = self.decoder(logits, low_level_feature, scale_index)
                logits = self.fc2(logits)
                logits = self.samples[scale_index](logits)
                return logits
            scale_index += 1
        # No scale matched: fall back to the raw backbone feature map.
        return feature_map
-
-
class SampleBlock(nn.Cell):
    """
    Bilinear resize block that scales the input image shape.

    Args:
        feature_shape (list): Shape whose indices 2 (h) and 3 (w) are read.
        scale_size (float): Scale factor applied to h and w. Default: 1.0.

    Returns:
        Tensor, the input resized with `P.ResizeBilinear(align_corners=True)`.
    """

    def __init__(self,
                 feature_shape,
                 scale_size=1.0):
        super(SampleBlock, self).__init__()
        # Scaled sizes are rounded up before being converted to int.
        target_h = int(np.ceil(float(feature_shape[2]) * scale_size))
        target_w = int(np.ceil(float(feature_shape[3]) * scale_size))
        self.sample = P.ResizeBilinear((target_h, target_w), align_corners=True)

    def construct(self, x):
        resized = self.sample(x)
        return resized
-
-
class DeepLabV3(nn.Cell):
    """
    DeepLabV3 model.

    In training mode the logits computed for every `image_pyramid` scale are
    averaged. Otherwise the softmax outputs computed for every
    `infer_scale_sizes` scale are averaged.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        backbone (Cell): Backbone network. If None, `resnet50_dl(fine_tune_batch_norm)`
            is created.
        channel (int): Backbone output channel.
        depth (int): ASPP block depth.
        infer_scale_sizes (list): The scales to resize images for inference.
            None is treated as an empty list.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        decoder_output_stride (int): The ratio of input to output spatial resolution.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.
        image_pyramid (list): Input scales used in training. None means [1.0].

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates,
                 decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid):
        super(DeepLabV3, self).__init__()
        # Normalise None once so the loops below and construct() can rely on a list.
        # The original guard was overwritten right after it ran, so None crashed here.
        if infer_scale_sizes is None:
            infer_scale_sizes = []
        self.infer_scale_sizes = infer_scale_sizes

        if image_pyramid is None:
            image_pyramid = [1.0]
        self.image_pyramid = image_pyramid

        # Training scales first, then inference scales; construct() indexes
        # self.samples with this layout.
        scale_sizes = []
        for pyramid in image_pyramid:
            scale_sizes.append(pyramid)
        for scale in infer_scale_sizes:
            scale_sizes.append(scale)

        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(SampleBlock(feature_shape, scale_size))
        self.samples = nn.CellList(self.samples)

        # Honour the caller's backbone instead of silently building a second one;
        # fall back to the previous default when none is given.
        if backbone is None:
            backbone = resnet50_dl(fine_tune_batch_norm)
        self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes,
                                         feature_shape=feature_shape,
                                         backbone=backbone,
                                         channel=channel,
                                         depth=depth,
                                         scale_sizes=scale_sizes,
                                         atrous_rates=atrous_rates,
                                         decoder_output_stride=decoder_output_stride,
                                         output_stride=output_stride,
                                         fine_tune_batch_norm=fine_tune_batch_norm)
        self.softmax = P.Softmax(axis=1)
        self.concat = P.Concat(axis=2)
        self.expand_dims = P.ExpandDims()
        self.reduce_mean = P.ReduceMean()
        # Resizes per-scale outputs back to the input image size.
        self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
                                               int(feature_shape[3])),
                                              align_corners=True)

    def construct(self, x):
        logits = ()
        if self.training:
            if len(self.image_pyramid) >= 1:
                if self.image_pyramid[0] == 1:
                    # Scale 1 needs neither input resizing nor output resizing.
                    logits = self.deeplabv3(x)
                else:
                    x1 = self.samples[0](x)
                    logits = self.deeplabv3(x1)
                    logits = self.sample_common(logits)
                # Stack per-scale results along a new axis 2 so they can be averaged.
                logits = self.expand_dims(logits, 2)
                for i in range(len(self.image_pyramid) - 1):
                    x_i = self.samples[i + 1](x)
                    logits_i = self.deeplabv3(x_i)
                    logits_i = self.sample_common(logits_i)
                    logits_i = self.expand_dims(logits_i, 2)
                    logits = self.concat((logits, logits_i))
            logits = self.reduce_mean(logits, 2)
            return logits
        if len(self.infer_scale_sizes) >= 1:
            # Inference resize blocks are stored after the training ones.
            infer_index = len(self.image_pyramid)
            x1 = self.samples[infer_index](x)
            logits = self.deeplabv3(x1)
            logits = self.sample_common(logits)
            logits = self.softmax(logits)
            logits = self.expand_dims(logits, 2)
            for i in range(len(self.infer_scale_sizes) - 1):
                x_i = self.samples[i + 1 + infer_index](x)
                logits_i = self.deeplabv3(x_i)
                logits_i = self.sample_common(logits_i)
                logits_i = self.softmax(logits_i)
                logits_i = self.expand_dims(logits_i, 2)
                logits = self.concat((logits, logits_i))
        logits = self.reduce_mean(logits, 2)
        return logits
-
-
def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid,
                       infer_scale_sizes, atrous_rates=None, decoder_output_stride=None,
                       output_stride=16, fine_tune_batch_norm=False):
    """
    ResNet50 based DeepLabv3 network.

    Builds a `DeepLabV3` cell with channel=2048 and depth=256, passing a
    `resnet50_dl(fine_tune_batch_norm)` instance as the backbone argument.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        image_pyramid (list): Input scales for multi-scale feature extraction.
        infer_scale_sizes (list): The scales to resize images for inference.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
            Default: None.
        decoder_output_stride (int): The ratio of input to output spatial resolution.
            Default: None.
        output_stride (int): The ratio of input to output spatial resolution.
            Default: 16.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.
            Default: False.

    Returns:
        Cell, cell instance of ResNet50 based DeepLabv3 neural network.

    Examples:
        >>> deeplabv3_resnet50(100, [1, 3, 224, 224], [1.0], [1.0])
    """
    backbone_channel = 2048
    aspp_depth = 256
    resnet_backbone = resnet50_dl(fine_tune_batch_norm)
    network = DeepLabV3(num_classes=num_classes,
                        feature_shape=feature_shape,
                        backbone=resnet_backbone,
                        channel=backbone_channel,
                        depth=aspp_depth,
                        infer_scale_sizes=infer_scale_sizes,
                        atrous_rates=atrous_rates,
                        decoder_output_stride=decoder_output_stride,
                        output_stride=output_stride,
                        fine_tune_batch_norm=fine_tune_batch_norm,
                        image_pyramid=image_pyramid)
    return network
|