You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

deeplabv3.py 21 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """DeepLabv3."""
  16. import numpy as np
  17. import mindspore.nn as nn
  18. from mindspore.ops import operations as P
  19. from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
  20. DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
  21. class ASPPSampleBlock(nn.Cell):
  22. """ASPP sample block."""
  23. def __init__(self, feature_shape, scale_size, output_stride):
  24. super(ASPPSampleBlock, self).__init__()
  25. sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1
  26. sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1
  27. self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)
  28. def construct(self, x):
  29. return self.sample(x)
  30. class ASPP(nn.Cell):
  31. """
  32. ASPP model for DeepLabv3.
  33. Args:
  34. channel (int): Input channel.
  35. depth (int): Output channel.
  36. feature_shape (list): The shape of feature,[h,w].
  37. scale_sizes (list): Input scales for multi-scale feature extraction.
  38. atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
  39. output_stride (int): 'The ratio of input to output spatial resolution.'
  40. fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
  41. Returns:
  42. Tensor, output tensor.
  43. Examples:
  44. >>> ASPP(channel=2048,256,[14,14],[1],[6],16)
  45. """
  46. def __init__(self, channel, depth, feature_shape, scale_sizes,
  47. atrous_rates, output_stride, fine_tune_batch_norm=False):
  48. super(ASPP, self).__init__()
  49. self.aspp0 = _conv_bn_relu(channel,
  50. depth,
  51. ksize=1,
  52. stride=1,
  53. use_batch_statistics=fine_tune_batch_norm)
  54. self.atrous_rates = []
  55. if atrous_rates is not None:
  56. self.atrous_rates = atrous_rates
  57. self.aspp_pointwise = _conv_bn_relu(channel,
  58. depth,
  59. ksize=1,
  60. stride=1,
  61. use_batch_statistics=fine_tune_batch_norm)
  62. self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel,
  63. channel_multiplier=1,
  64. kernel_size=3,
  65. stride=1,
  66. dilation=1,
  67. pad_mode="valid")
  68. self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm)
  69. self.aspp_depth_relu = nn.ReLU()
  70. self.aspp_depths = []
  71. self.aspp_depth_spacetobatchs = []
  72. self.aspp_depth_batchtospaces = []
  73. for scale_size in scale_sizes:
  74. aspp_scale_depth_size = np.ceil((feature_shape[0]*scale_size)/16)
  75. if atrous_rates is None:
  76. break
  77. for rate in atrous_rates:
  78. padding = 0
  79. for j in range(100):
  80. padded_size = rate * j
  81. if padded_size >= aspp_scale_depth_size + 2 * rate:
  82. padding = padded_size - aspp_scale_depth_size - 2 * rate
  83. break
  84. paddings = [[rate, rate + int(padding)],
  85. [rate, rate + int(padding)]]
  86. self.aspp_depth_spacetobatch = SpaceToBatch(rate, paddings)
  87. self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch)
  88. crops = [[0, int(padding)], [0, int(padding)]]
  89. self.aspp_depth_batchtospace = BatchToSpace(rate, crops)
  90. self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace)
  91. self.aspp_depths = nn.CellList(self.aspp_depths)
  92. self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs)
  93. self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces)
  94. self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1])))
  95. self.global_poolings = []
  96. for scale_size in scale_sizes:
  97. pooling_h = np.ceil((feature_shape[0]*scale_size)/output_stride)
  98. pooling_w = np.ceil((feature_shape[0]*scale_size)/output_stride)
  99. self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w))))
  100. self.global_poolings = nn.CellList(self.global_poolings)
  101. self.conv_bn = _conv_bn_relu(channel,
  102. depth,
  103. ksize=1,
  104. stride=1,
  105. use_batch_statistics=fine_tune_batch_norm)
  106. self.samples = []
  107. for scale_size in scale_sizes:
  108. self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride))
  109. self.samples = nn.CellList(self.samples)
  110. self.feature_shape = feature_shape
  111. self.concat = P.Concat(axis=1)
  112. def construct(self, x, scale_index=0):
  113. aspp0 = self.aspp0(x)
  114. aspp1 = self.global_poolings[scale_index](x)
  115. aspp1 = self.conv_bn(aspp1)
  116. aspp1 = self.samples[scale_index](aspp1)
  117. output = self.concat((aspp1, aspp0))
  118. for i in range(len(self.atrous_rates)):
  119. aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x)
  120. aspp_i = self.aspp_depth_depthwiseconv(aspp_i)
  121. aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i)
  122. aspp_i = self.aspp_depth_bn(aspp_i)
  123. aspp_i = self.aspp_depth_relu(aspp_i)
  124. aspp_i = self.aspp_pointwise(aspp_i)
  125. output = self.concat((output, aspp_i))
  126. return output
  127. class DecoderSampleBlock(nn.Cell):
  128. """Decoder sample block."""
  129. def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4):
  130. super(DecoderSampleBlock, self).__init__()
  131. sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1
  132. sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1
  133. self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)
  134. def construct(self, x):
  135. return self.sample(x)
  136. class Decoder(nn.Cell):
  137. """
  138. Decode module for DeepLabv3.
  139. Args:
  140. low_level_channel (int): Low level input channel
  141. channel (int): Input channel.
  142. depth (int): Output channel.
  143. feature_shape (list): 'Input image shape, [N,C,H,W].'
  144. scale_sizes (list): 'Input scales for multi-scale feature extraction.'
  145. decoder_output_stride (int): 'The ratio of input to output spatial resolution'
  146. fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
  147. Returns:
  148. Tensor, output tensor.
  149. Examples:
  150. >>> Decoder(256, 100, [56,56])
  151. """
  152. def __init__(self,
  153. low_level_channel,
  154. channel,
  155. depth,
  156. feature_shape,
  157. scale_sizes,
  158. decoder_output_stride,
  159. fine_tune_batch_norm):
  160. super(Decoder, self).__init__()
  161. self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1,
  162. pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
  163. self.decoder_depth0 = _deep_conv_bn_relu(channel + 48,
  164. channel_multiplier=1,
  165. ksize=3,
  166. stride=1,
  167. pad_mode="same",
  168. dilation=1,
  169. use_batch_statistics=fine_tune_batch_norm)
  170. self.decoder_pointwise0 = _conv_bn_relu(channel + 48,
  171. depth,
  172. ksize=1,
  173. stride=1,
  174. use_batch_statistics=fine_tune_batch_norm)
  175. self.decoder_depth1 = _deep_conv_bn_relu(depth,
  176. channel_multiplier=1,
  177. ksize=3,
  178. stride=1,
  179. pad_mode="same",
  180. dilation=1,
  181. use_batch_statistics=fine_tune_batch_norm)
  182. self.decoder_pointwise1 = _conv_bn_relu(depth,
  183. depth,
  184. ksize=1,
  185. stride=1,
  186. use_batch_statistics=fine_tune_batch_norm)
  187. self.depth = depth
  188. self.concat = P.Concat(axis=1)
  189. self.samples = []
  190. for scale_size in scale_sizes:
  191. self.samples.append(DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride))
  192. self.samples = nn.CellList(self.samples)
  193. def construct(self, x, low_level_feature, scale_index):
  194. low_level_feature = self.feature_projection(low_level_feature)
  195. low_level_feature = self.samples[scale_index](low_level_feature)
  196. x = self.samples[scale_index](x)
  197. output = self.concat((x, low_level_feature))
  198. output = self.decoder_depth0(output)
  199. output = self.decoder_pointwise0(output)
  200. output = self.decoder_depth1(output)
  201. output = self.decoder_pointwise1(output)
  202. return output
  203. class SingleDeepLabV3(nn.Cell):
  204. """
  205. DeepLabv3 Network.
  206. Args:
  207. num_classes (int): Class number.
  208. feature_shape (list): Input image shape, [N,C,H,W].
  209. backbone (Cell): Backbone Network.
  210. channel (int): Resnet output channel.
  211. depth (int): ASPP block depth.
  212. scale_sizes (list): Input scales for multi-scale feature extraction.
  213. atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
  214. decoder_output_stride (int): 'The ratio of input to output spatial resolution'
  215. output_stride (int): 'The ratio of input to output spatial resolution.'
  216. fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
  217. Returns:
  218. Tensor, output tensor.
  219. Examples:
  220. >>> SingleDeepLabV3(num_classes=10,
  221. >>> feature_shape=[1,3,224,224],
  222. >>> backbone=resnet50_dl(),
  223. >>> channel=2048,
  224. >>> depth=256)
  225. >>> scale_sizes=[1.0])
  226. >>> atrous_rates=[6])
  227. >>> decoder_output_stride=4)
  228. >>> output_stride=16)
  229. """
  230. def __init__(self,
  231. num_classes,
  232. feature_shape,
  233. backbone,
  234. channel,
  235. depth,
  236. scale_sizes,
  237. atrous_rates,
  238. decoder_output_stride,
  239. output_stride,
  240. fine_tune_batch_norm=False):
  241. super(SingleDeepLabV3, self).__init__()
  242. self.num_classes = num_classes
  243. self.channel = channel
  244. self.depth = depth
  245. self.scale_sizes = []
  246. for scale_size in np.sort(scale_sizes):
  247. self.scale_sizes.append(scale_size)
  248. self.net = backbone
  249. self.aspp = ASPP(channel=self.channel,
  250. depth=self.depth,
  251. feature_shape=[feature_shape[2],
  252. feature_shape[3]],
  253. scale_sizes=self.scale_sizes,
  254. atrous_rates=atrous_rates,
  255. output_stride=output_stride,
  256. fine_tune_batch_norm=fine_tune_batch_norm)
  257. self.aspp.add_flags(loop_can_unroll=True)
  258. atrous_rates_len = 0
  259. if atrous_rates is not None:
  260. atrous_rates_len = len(atrous_rates)
  261. self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth,
  262. ksize=1,
  263. stride=1,
  264. use_batch_statistics=fine_tune_batch_norm)
  265. self.fc2 = nn.Conv2d(depth,
  266. num_classes,
  267. kernel_size=1,
  268. stride=1,
  269. has_bias=True)
  270. self.upsample = P.ResizeBilinear((int(feature_shape[2]),
  271. int(feature_shape[3])),
  272. align_corners=True)
  273. self.samples = []
  274. for scale_size in self.scale_sizes:
  275. self.samples.append(SampleBlock(feature_shape, scale_size))
  276. self.samples = nn.CellList(self.samples)
  277. self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]),
  278. float(feature_shape[3])]
  279. self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
  280. self.dropout = nn.Dropout(keep_prob=0.9)
  281. self.shape = P.Shape()
  282. self.decoder_output_stride = decoder_output_stride
  283. if decoder_output_stride is not None:
  284. self.decoder = Decoder(low_level_channel=depth,
  285. channel=depth,
  286. depth=depth,
  287. feature_shape=[feature_shape[2],
  288. feature_shape[3]],
  289. scale_sizes=self.scale_sizes,
  290. decoder_output_stride=decoder_output_stride,
  291. fine_tune_batch_norm=fine_tune_batch_norm)
  292. def construct(self, x, scale_index=0):
  293. x = (2.0 / 255.0) * x - 1.0
  294. x = self.pad(x)
  295. low_level_feature, feature_map = self.net(x)
  296. for scale_size in self.scale_sizes:
  297. if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2:
  298. output = self.aspp(feature_map, scale_index)
  299. output = self.fc1(output)
  300. if self.decoder_output_stride is not None:
  301. output = self.decoder(output, low_level_feature, scale_index)
  302. output = self.fc2(output)
  303. output = self.samples[scale_index](output)
  304. return output
  305. scale_index += 1
  306. return feature_map
  307. class SampleBlock(nn.Cell):
  308. """Sample block."""
  309. def __init__(self,
  310. feature_shape,
  311. scale_size=1.0):
  312. super(SampleBlock, self).__init__()
  313. sample_h = np.ceil(float(feature_shape[2]) * scale_size)
  314. sample_w = np.ceil(float(feature_shape[3]) * scale_size)
  315. self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)
  316. def construct(self, x):
  317. return self.sample(x)
  318. class DeepLabV3(nn.Cell):
  319. """DeepLabV3 model."""
  320. def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates,
  321. decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid):
  322. super(DeepLabV3, self).__init__()
  323. self.infer_scale_sizes = []
  324. if infer_scale_sizes is not None:
  325. self.infer_scale_sizes = infer_scale_sizes
  326. self.infer_scale_sizes = infer_scale_sizes
  327. if image_pyramid is None:
  328. image_pyramid = [1.0]
  329. self.image_pyramid = image_pyramid
  330. scale_sizes = []
  331. for pyramid in image_pyramid:
  332. scale_sizes.append(pyramid)
  333. for scale in infer_scale_sizes:
  334. scale_sizes.append(scale)
  335. self.samples = []
  336. for scale_size in scale_sizes:
  337. self.samples.append(SampleBlock(feature_shape, scale_size))
  338. self.samples = nn.CellList(self.samples)
  339. self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes,
  340. feature_shape=feature_shape,
  341. backbone=resnet50_dl(fine_tune_batch_norm),
  342. channel=channel,
  343. depth=depth,
  344. scale_sizes=scale_sizes,
  345. atrous_rates=atrous_rates,
  346. decoder_output_stride=decoder_output_stride,
  347. output_stride=output_stride,
  348. fine_tune_batch_norm=fine_tune_batch_norm)
  349. self.softmax = P.Softmax(axis=1)
  350. self.concat = P.Concat(axis=2)
  351. self.expand_dims = P.ExpandDims()
  352. self.reduce_mean = P.ReduceMean()
  353. self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
  354. int(feature_shape[3])),
  355. align_corners=True)
  356. def construct(self, x):
  357. logits = ()
  358. if self.training:
  359. if len(self.image_pyramid) >= 1:
  360. if self.image_pyramid[0] == 1:
  361. logits = self.deeplabv3(x)
  362. else:
  363. x1 = self.samples[0](x)
  364. logits = self.deeplabv3(x1)
  365. logits = self.sample_common(logits)
  366. logits = self.expand_dims(logits, 2)
  367. for i in range(len(self.image_pyramid) - 1):
  368. x_i = self.samples[i + 1](x)
  369. logits_i = self.deeplabv3(x_i)
  370. logits_i = self.sample_common(logits_i)
  371. logits_i = self.expand_dims(logits_i, 2)
  372. logits = self.concat((logits, logits_i))
  373. logits = self.reduce_mean(logits, 2)
  374. return logits
  375. if len(self.infer_scale_sizes) >= 1:
  376. infer_index = len(self.image_pyramid)
  377. x1 = self.samples[infer_index](x)
  378. logits = self.deeplabv3(x1)
  379. logits = self.sample_common(logits)
  380. logits = self.softmax(logits)
  381. logits = self.expand_dims(logits, 2)
  382. for i in range(len(self.infer_scale_sizes) - 1):
  383. x_i = self.samples[i + 1 + infer_index](x)
  384. logits_i = self.deeplabv3(x_i)
  385. logits_i = self.sample_common(logits_i)
  386. logits_i = self.softmax(logits_i)
  387. logits_i = self.expand_dims(logits_i, 2)
  388. logits = self.concat((logits, logits_i))
  389. logits = self.reduce_mean(logits, 2)
  390. return logits
  391. def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid,
  392. infer_scale_sizes, atrous_rates=None, decoder_output_stride=None,
  393. output_stride=16, fine_tune_batch_norm=False):
  394. """
  395. ResNet50 based DeepLabv3 network.
  396. Args:
  397. num_classes (int): Class number.
  398. feature_shape (list): Input image shape, [N,C,H,W].
  399. image_pyramid (list): Input scales for multi-scale feature extraction.
  400. atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
  401. infer_scale_sizes (list): 'The scales to resize images for inference.
  402. decoder_output_stride (int): 'The ratio of input to output spatial resolution'
  403. output_stride (int): 'The ratio of input to output spatial resolution.'
  404. fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
  405. Returns:
  406. Cell, cell instance of ResNet50 based DeepLabv3 neural network.
  407. Examples:
  408. >>> deeplabv3_resnet50(100, [1,3,224,224],[1.0],[1.0])
  409. """
  410. return DeepLabV3(num_classes=num_classes,
  411. feature_shape=feature_shape,
  412. backbone=resnet50_dl(fine_tune_batch_norm),
  413. channel=2048,
  414. depth=256,
  415. infer_scale_sizes=infer_scale_sizes,
  416. atrous_rates=atrous_rates,
  417. decoder_output_stride=decoder_output_stride,
  418. output_stride=output_stride,
  419. fine_tune_batch_norm=fine_tune_batch_norm,
  420. image_pyramid=image_pyramid)