|
|
|
@@ -93,31 +93,30 @@ def _stob_deep_conv_btos_bn_relu(in_channel, |
|
|
|
|
|
|
|
|
|
|
|
def _stob_conv_btos_bn_relu(in_channel, |
|
|
|
out_channel, |
|
|
|
ksize, |
|
|
|
space_to_batch_block_shape, |
|
|
|
batch_to_space_block_shape, |
|
|
|
paddings, |
|
|
|
crops, |
|
|
|
stride=1, |
|
|
|
padding=0, |
|
|
|
dilation=1, |
|
|
|
pad_mode="pad", |
|
|
|
use_batch_statistics=False): |
|
|
|
out_channel, |
|
|
|
ksize, |
|
|
|
space_to_batch_block_shape, |
|
|
|
batch_to_space_block_shape, |
|
|
|
paddings, |
|
|
|
crops, |
|
|
|
stride=1, |
|
|
|
padding=0, |
|
|
|
dilation=1, |
|
|
|
pad_mode="pad", |
|
|
|
use_batch_statistics=False): |
|
|
|
"""Get a spacetobatch -> conv2d -> batchnorm -> relu -> batchtospace layer""" |
|
|
|
return nn.SequentialCell( |
|
|
|
[SpaceToBatch(space_to_batch_block_shape,paddings), |
|
|
|
nn.Conv2d(in_channel, |
|
|
|
out_channel, |
|
|
|
kernel_size=ksize, |
|
|
|
stride=stride, |
|
|
|
padding=padding, |
|
|
|
dilation=dilation, |
|
|
|
pad_mode=pad_mode), |
|
|
|
BatchToSpace(batch_to_space_block_shape,crops), |
|
|
|
nn.BatchNorm2d(out_channel,use_batch_statistics=use_batch_statistics), |
|
|
|
nn.ReLU()] |
|
|
|
) |
|
|
|
return nn.SequentialCell([SpaceToBatch(space_to_batch_block_shape, paddings), |
|
|
|
nn.Conv2d(in_channel, |
|
|
|
out_channel, |
|
|
|
kernel_size=ksize, |
|
|
|
stride=stride, |
|
|
|
padding=padding, |
|
|
|
dilation=dilation, |
|
|
|
pad_mode=pad_mode), |
|
|
|
BatchToSpace(batch_to_space_block_shape, crops), |
|
|
|
nn.BatchNorm2d(out_channel, use_batch_statistics=use_batch_statistics), |
|
|
|
nn.ReLU()] |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def _make_layer(block, |
|
|
|
@@ -206,6 +205,7 @@ class BatchToSpace(nn.Cell): |
|
|
|
|
|
|
|
|
|
|
|
class _DepthwiseConv2dNative(nn.Cell): |
|
|
|
"""Depthwise Conv2D Cell.""" |
|
|
|
def __init__(self, |
|
|
|
in_channels, |
|
|
|
channel_multiplier, |
|
|
|
@@ -242,6 +242,7 @@ class _DepthwiseConv2dNative(nn.Cell): |
|
|
|
|
|
|
|
|
|
|
|
class DepthwiseConv2dNative(_DepthwiseConv2dNative): |
|
|
|
"""Depthwise Conv2D Cell.""" |
|
|
|
def __init__(self, |
|
|
|
in_channels, |
|
|
|
channel_multiplier, |
|
|
|
@@ -315,31 +316,31 @@ class BottleneckV1(nn.Cell): |
|
|
|
padding=1, |
|
|
|
dilation=1, |
|
|
|
use_batch_statistics=use_batch_statistics) |
|
|
|
if use_batch_to_stob_and_btos == True: |
|
|
|
if use_batch_to_stob_and_btos: |
|
|
|
self.conv_bn2 = _stob_conv_btos_bn_relu(mid_channels, |
|
|
|
mid_channels, |
|
|
|
ksize=3, |
|
|
|
stride=stride, |
|
|
|
padding=0, |
|
|
|
dilation=1, |
|
|
|
space_to_batch_block_shape = 2, |
|
|
|
batch_to_space_block_shape = 2, |
|
|
|
paddings =[[2, 3], [2, 3]], |
|
|
|
crops =[[0, 1], [0, 1]], |
|
|
|
space_to_batch_block_shape=2, |
|
|
|
batch_to_space_block_shape=2, |
|
|
|
paddings=[[2, 3], [2, 3]], |
|
|
|
crops=[[0, 1], [0, 1]], |
|
|
|
pad_mode="valid", |
|
|
|
use_batch_statistics=use_batch_statistics) |
|
|
|
|
|
|
|
|
|
|
|
self.conv3 = nn.Conv2d(mid_channels, |
|
|
|
out_channels, |
|
|
|
kernel_size=1, |
|
|
|
stride=1) |
|
|
|
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) |
|
|
|
self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) |
|
|
|
if in_channels != out_channels: |
|
|
|
conv = nn.Conv2d(in_channels, |
|
|
|
out_channels, |
|
|
|
kernel_size=1, |
|
|
|
stride=stride) |
|
|
|
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) |
|
|
|
bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) |
|
|
|
self.downsample = nn.SequentialCell([conv, bn]) |
|
|
|
else: |
|
|
|
self.downsample = Subsample(stride) |
|
|
|
@@ -397,23 +398,23 @@ class BottleneckV2(nn.Cell): |
|
|
|
stride=stride, |
|
|
|
padding=0, |
|
|
|
dilation=1, |
|
|
|
space_to_batch_block_shape = 2, |
|
|
|
batch_to_space_block_shape = 2, |
|
|
|
paddings =[[2, 3], [2, 3]], |
|
|
|
crops =[[0, 1], [0, 1]], |
|
|
|
space_to_batch_block_shape=2, |
|
|
|
batch_to_space_block_shape=2, |
|
|
|
paddings=[[2, 3], [2, 3]], |
|
|
|
crops=[[0, 1], [0, 1]], |
|
|
|
pad_mode="valid", |
|
|
|
use_batch_statistics=use_batch_statistics) |
|
|
|
self.conv3 = nn.Conv2d(mid_channels, |
|
|
|
out_channels, |
|
|
|
kernel_size=1, |
|
|
|
stride=1) |
|
|
|
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) |
|
|
|
self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) |
|
|
|
if in_channels != out_channels: |
|
|
|
conv = nn.Conv2d(in_channels, |
|
|
|
out_channels, |
|
|
|
kernel_size=1, |
|
|
|
stride=stride) |
|
|
|
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) |
|
|
|
bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) |
|
|
|
self.downsample = nn.SequentialCell([conv, bn]) |
|
|
|
else: |
|
|
|
self.downsample = Subsample(stride) |
|
|
|
@@ -465,14 +466,14 @@ class BottleneckV3(nn.Cell): |
|
|
|
out_channels, |
|
|
|
kernel_size=1, |
|
|
|
stride=1) |
|
|
|
self.bn3 = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) |
|
|
|
|
|
|
|
self.bn3 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) |
|
|
|
|
|
|
|
if in_channels != out_channels: |
|
|
|
conv = nn.Conv2d(in_channels, |
|
|
|
out_channels, |
|
|
|
kernel_size=1, |
|
|
|
stride=stride) |
|
|
|
bn = nn.BatchNorm2d(out_channels,use_batch_statistics=use_batch_statistics) |
|
|
|
bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics) |
|
|
|
self.downsample = nn.SequentialCell([conv, bn]) |
|
|
|
else: |
|
|
|
self.downsample = Subsample(stride) |
|
|
|
@@ -502,9 +503,8 @@ class ResNetV1(nn.Cell): |
|
|
|
super(ResNetV1, self).__init__() |
|
|
|
self.layer_root = nn.SequentialCell( |
|
|
|
[RootBlockBeta(fine_tune_batch_norm), |
|
|
|
nn.MaxPool2d(kernel_size=(3,3), |
|
|
|
stride=(2,2), |
|
|
|
#padding=1, |
|
|
|
nn.MaxPool2d(kernel_size=(3, 3), |
|
|
|
stride=(2, 2), |
|
|
|
pad_mode='same')]) |
|
|
|
self.layer1_1 = BottleneckV1(128, 256, stride=1, use_batch_statistics=fine_tune_batch_norm) |
|
|
|
self.layer1_2 = BottleneckV2(256, 256, stride=1, use_batch_statistics=fine_tune_batch_norm) |
|
|
|
@@ -519,7 +519,7 @@ class ResNetV1(nn.Cell): |
|
|
|
self.layer3_4 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) |
|
|
|
self.layer3_5 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) |
|
|
|
self.layer3_6 = BottleneckV2(1024, 1024, stride=1, use_batch_statistics=fine_tune_batch_norm) |
|
|
|
|
|
|
|
|
|
|
|
self.layer4_1 = BottleneckV1(1024, 2048, stride=1, use_batch_to_stob_and_btos=True, |
|
|
|
use_batch_statistics=fine_tune_batch_norm) |
|
|
|
self.layer4_2 = BottleneckV2(2048, 2048, stride=1, use_batch_to_stob_and_btos=True, |
|
|
|
@@ -542,7 +542,7 @@ class ResNetV1(nn.Cell): |
|
|
|
x = self.layer3_4(x) |
|
|
|
x = self.layer3_5(x) |
|
|
|
x = self.layer3_6(x) |
|
|
|
|
|
|
|
|
|
|
|
x = self.layer4_1(x) |
|
|
|
x = self.layer4_2(x) |
|
|
|
c5 = self.layer4_3(x) |
|
|
|
|