Browse Source

fix pylint problem

tags/v0.5.0-beta
yangyongjie 5 years ago
parent
commit
52c59900a7
5 changed files with 117 additions and 66 deletions
  1. +1
    -2
      model_zoo/deeplabv3/src/__init__.py
  2. +18
    -5
      model_zoo/deeplabv3/src/backbone/__init__.py
  3. +96
    -56
      model_zoo/deeplabv3/src/backbone/resnet_deeplab.py
  4. +1
    -1
      model_zoo/deeplabv3/src/config.py
  5. +1
    -2
      model_zoo/deeplabv3/src/utils/__init__.py

+ 1
- 2
model_zoo/deeplabv3/src/__init__.py View File

@@ -14,11 +14,10 @@
# ============================================================================
"""Init DeepLabv3."""
from .deeplabv3 import ASPP, DeepLabV3, deeplabv3_resnet50
from . import backbone
from .backbone import *

__all__ = [
"ASPP", "DeepLabV3", "deeplabv3_resnet50", "Decoder"
"ASPP", "DeepLabV3", "deeplabv3_resnet50"
]

__all__.extend(backbone.__all__)

+ 18
- 5
model_zoo/deeplabv3/src/backbone/__init__.py View File

@@ -1,8 +1,21 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init backbone."""
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
RootBlockBeta, resnet50_dl
__all__= [
"Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta",
"resnet50_dl"

__all__ = [
"Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta", "resnet50_dl"
]

+ 96
- 56
model_zoo/deeplabv3/src/backbone/resnet_deeplab.py View File

@@ -1,4 +1,3 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,12 +15,11 @@
"""ResNet based DeepLab."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
import numpy as np
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore._checkparam import check_bool, twice
from mindspore import log as logger
from mindspore.common.initializer import initializer
from mindspore._checkparam import twice
from mindspore.common.parameter import Parameter


def _conv_bn_relu(in_channel,
out_channel,
ksize,
@@ -42,52 +40,58 @@ def _conv_bn_relu(in_channel,
nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics),
nn.ReLU()]
)


def _deep_conv_bn_relu(in_channel,
channel_multiplier,
ksize,
stride=1,
padding=0,
dilation=1,
pad_mode="pad",
use_batch_statistics=False):
channel_multiplier,
ksize,
stride=1,
padding=0,
dilation=1,
pad_mode="pad",
use_batch_statistics=False):
"""Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer"""
return nn.SequentialCell(
[DepthwiseConv2dNative(in_channel,
channel_multiplier,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
channel_multiplier,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics),
nn.ReLU()]
)


def _stob_deep_conv_btos_bn_relu(in_channel,
channel_multiplier,
ksize,
space_to_batch_block_shape,
batch_to_space_block_shape,
paddings,
crops,
stride=1,
padding=0,
dilation=1,
pad_mode="pad",
use_batch_statistics=False):
channel_multiplier,
ksize,
space_to_batch_block_shape,
batch_to_space_block_shape,
paddings,
crops,
stride=1,
padding=0,
dilation=1,
pad_mode="pad",
use_batch_statistics=False):
"""Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer"""
return nn.SequentialCell(
[SpaceToBatch(space_to_batch_block_shape,paddings),
[SpaceToBatch(space_to_batch_block_shape, paddings),
DepthwiseConv2dNative(in_channel,
channel_multiplier,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
BatchToSpace(batch_to_space_block_shape,crops),
channel_multiplier,
kernel_size=ksize,
stride=stride,
padding=padding,
dilation=dilation,
pad_mode=pad_mode),
BatchToSpace(batch_to_space_block_shape, crops),
nn.BatchNorm2d(channel_multiplier * in_channel, use_batch_statistics=use_batch_statistics),
nn.ReLU()]
)
)


def _stob_conv_btos_bn_relu(in_channel,
out_channel,
ksize,
@@ -114,6 +118,8 @@ def _stob_conv_btos_bn_relu(in_channel,
nn.BatchNorm2d(out_channel,use_batch_statistics=use_batch_statistics),
nn.ReLU()]
)


def _make_layer(block,
in_channels,
out_channels,
@@ -153,6 +159,8 @@ def _make_layer(block,
in_channels = out_channels
layer = nn.SequentialCell(blocks)
return layer, g_current_stride, g_rate


class Subsample(nn.Cell):
"""
Subsample for DeepLab-ResNet.
@@ -168,26 +176,35 @@ class Subsample(nn.Cell):
self.factor = factor
self.pool = nn.MaxPool2d(kernel_size=1,
stride=factor)

def construct(self, x):
if self.factor == 1:
return x
return self.pool(x)


class SpaceToBatch(nn.Cell):
def __init__(self, block_shape, paddings):
super(SpaceToBatch, self).__init__()
self.space_to_batch = P.SpaceToBatch(block_shape, paddings)
self.bs = block_shape
self.pd = paddings

def construct(self, x):
return self.space_to_batch(x)


class BatchToSpace(nn.Cell):
def __init__(self, block_shape, crops):
super(BatchToSpace, self).__init__()
self.batch_to_space = P.BatchToSpace(block_shape, crops)
self.bs = block_shape
self.cr = crops

def construct(self, x):
return self.batch_to_space(x)


class _DepthwiseConv2dNative(nn.Cell):
def __init__(self,
in_channels,
@@ -218,9 +235,12 @@ class _DepthwiseConv2dNative(nn.Cell):
+ str(self.kernel_size) + ', should be a int or tuple and equal to or greater than 1.')
self.weight = Parameter(initializer(weight_init, [1, in_channels // group, *kernel_size]),
name='weight')

def construct(self, *inputs):
"""Must be overridden by all subclasses."""
raise NotImplementedError


class DepthwiseConv2dNative(_DepthwiseConv2dNative):
def __init__(self,
in_channels,
@@ -244,18 +264,22 @@ class DepthwiseConv2dNative(_DepthwiseConv2dNative):
group,
weight_init)
self.depthwise_conv2d_native = P.DepthwiseConv2dNative(channel_multiplier=self.channel_multiplier,
kernel_size=self.kernel_size,
mode=3,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
kernel_size=self.kernel_size,
mode=3,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)

def set_strategy(self, strategy):
self.depthwise_conv2d_native.set_strategy(strategy)
return self

def construct(self, x):
return self.depthwise_conv2d_native(x, self.weight)


class BottleneckV1(nn.Cell):
"""
ResNet V1 BottleneckV1 block definition.
@@ -322,6 +346,7 @@ class BottleneckV1(nn.Cell):
self.add = P.TensorAdd()
self.relu = nn.ReLU()
self.Reshape = P.Reshape()

def construct(self, x):
out = self.conv_bn1(x)
out = self.conv_bn2(out)
@@ -329,8 +354,8 @@ class BottleneckV1(nn.Cell):
out = self.add(out, self.downsample(x))
out = self.relu(out)
return out
return out
class BottleneckV2(nn.Cell):
"""
ResNet V2 Bottleneck variance V2 block definition.
@@ -365,7 +390,7 @@ class BottleneckV2(nn.Cell):
padding=1,
dilation=dilation,
use_batch_statistics=use_batch_statistics)
if use_batch_to_stob_and_btos == True:
if use_batch_to_stob_and_btos:
self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels,
mid_channels,
ksize=3,
@@ -394,6 +419,7 @@ class BottleneckV2(nn.Cell):
self.downsample = Subsample(stride)
self.add = P.TensorAdd()
self.relu = nn.ReLU()

def construct(self, x):
out = self.conv_bn1(x)
out = self.conv_bn2(out)
@@ -402,6 +428,7 @@ class BottleneckV2(nn.Cell):
out = self.relu(out)
return out


class BottleneckV3(nn.Cell):
"""
ResNet V1 Bottleneck variance V1 block definition.
@@ -452,6 +479,7 @@ class BottleneckV3(nn.Cell):
self.downsample = Subsample(stride)
self.add = P.TensorAdd()
self.relu = nn.ReLU()

def construct(self, x):
out = self.conv_bn1(x)
out = self.conv_bn2(out)
@@ -460,6 +488,7 @@ class BottleneckV3(nn.Cell):
out = self.relu(out)
return out


class ResNetV1(nn.Cell):
"""
ResNet V1 for DeepLab.
@@ -491,9 +520,13 @@ class ResNetV1(nn.Cell):
self.layer3_5 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer3_6 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm)
self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True, use_batch_statistics=fine_tune_batch_norm)
self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True, use_batch_statistics=fine_tune_batch_norm)
self.layer4_3 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True, use_batch_statistics=fine_tune_batch_norm)
self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True,
use_batch_statistics=fine_tune_batch_norm)
self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True,
use_batch_statistics=fine_tune_batch_norm)
self.layer4_3 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True,
use_batch_statistics=fine_tune_batch_norm)

def construct(self, x):
x = self.layer_root(x)
x = self.layer1_1(x)
@@ -514,6 +547,8 @@ class ResNetV1(nn.Cell):
x = self.layer4_2(x)
c5 = self.layer4_3(x)
return c2, c5


class RootBlockBeta(nn.Cell):
"""
ResNet V1 beta root block definition.
@@ -524,14 +559,19 @@ class RootBlockBeta(nn.Cell):
"""
def __init__(self, fine_tune_batch_norm=False):
super(RootBlockBeta, self).__init__()
self.conv1 = _conv_bn_relu(3, 64, ksize=3, stride=2, padding=0, pad_mode="valid", use_batch_statistics=fine_tune_batch_norm)
self.conv2 = _conv_bn_relu(64, 64, ksize=3, stride=1, padding=0, pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
self.conv3 = _conv_bn_relu(64, 128, ksize=3, stride=1, padding=0, pad_m ode="same", use_batch_statistics=fine_tune_batch_norm)
self.conv1 = _conv_bn_relu(3, 64, ksize=3, stride=2, padding=0, pad_mode="valid",
use_batch_statistics=fine_tune_batch_norm)
self.conv2 = _conv_bn_relu(64, 64, ksize=3, stride=1, padding=0, pad_mode="same",
use_batch_statistics=fine_tune_batch_norm)
self.conv3 = _conv_bn_relu(64, 128, ksize=3, stride=1, padding=0, pad_mode="same",
use_batch_statistics=fine_tune_batch_norm)

def construct(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
return x


def resnet50_dl(fine_tune_batch_norm=False):
return ResNetV1(fine_tune_batch_norm)

+ 1
- 1
model_zoo/deeplabv3/src/config.py View File

@@ -30,4 +30,4 @@ config = ed({
"ignore_label": 255,
"decoder_output_stride": None,
"seg_num_classes": 21
})
})

+ 1
- 2
model_zoo/deeplabv3/src/utils/__init__.py View File

@@ -1,4 +1,3 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# ============================================================================

Loading…
Cancel
Save