You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

deeplabv3.py 21 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """DeepLabv3."""
  16. import numpy as np
  17. import mindspore.nn as nn
  18. from mindspore.ops import operations as P
  19. from mindspore.ops.composite import add_flags
  20. from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
  21. DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
  22. class ASPPSampleBlock(nn.Cell):
  23. """ASPP sample block."""
  24. def __init__(self, feature_shape, scale_size, output_stride):
  25. super(ASPPSampleBlock, self).__init__()
  26. sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1
  27. sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1
  28. self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)
  29. def construct(self, x):
  30. return self.sample(x)
  31. class ASPP(nn.Cell):
  32. """
  33. ASPP model for DeepLabv3.
  34. Args:
  35. channel (int): Input channel.
  36. depth (int): Output channel.
  37. feature_shape (list): The shape of feature,[h,w].
  38. scale_sizes (list): Input scales for multi-scale feature extraction.
  39. atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
  40. output_stride (int): 'The ratio of input to output spatial resolution.'
  41. fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
  42. Returns:
  43. Tensor, output tensor.
  44. Examples:
  45. >>> ASPP(channel=2048,256,[14,14],[1],[6],16)
  46. """
  47. def __init__(self, channel, depth, feature_shape, scale_sizes,
  48. atrous_rates, output_stride, fine_tune_batch_norm=False):
  49. super(ASPP, self).__init__()
  50. self.aspp0 = _conv_bn_relu(channel,
  51. depth,
  52. ksize=1,
  53. stride=1,
  54. use_batch_statistics=fine_tune_batch_norm)
  55. self.atrous_rates = []
  56. if atrous_rates is not None:
  57. self.atrous_rates = atrous_rates
  58. self.aspp_pointwise = _conv_bn_relu(channel,
  59. depth,
  60. ksize=1,
  61. stride=1,
  62. use_batch_statistics=fine_tune_batch_norm)
  63. self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel,
  64. channel_multiplier=1,
  65. kernel_size=3,
  66. stride=1,
  67. dilation=1,
  68. pad_mode="valid")
  69. self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm)
  70. self.aspp_depth_relu = nn.ReLU()
  71. self.aspp_depths = []
  72. self.aspp_depth_spacetobatchs = []
  73. self.aspp_depth_batchtospaces = []
  74. for scale_size in scale_sizes:
  75. aspp_scale_depth_size = np.ceil((feature_shape[0]*scale_size)/16)
  76. if atrous_rates is None:
  77. break
  78. for rate in atrous_rates:
  79. padding = 0
  80. for j in range(100):
  81. padded_size = rate * j
  82. if padded_size >= aspp_scale_depth_size + 2 * rate:
  83. padding = padded_size - aspp_scale_depth_size - 2 * rate
  84. break
  85. paddings = [[rate, rate + int(padding)],
  86. [rate, rate + int(padding)]]
  87. self.aspp_depth_spacetobatch = SpaceToBatch(rate, paddings)
  88. self.aspp_depth_spacetobatchs.append(self.aspp_depth_spacetobatch)
  89. crops = [[0, int(padding)], [0, int(padding)]]
  90. self.aspp_depth_batchtospace = BatchToSpace(rate, crops)
  91. self.aspp_depth_batchtospaces.append(self.aspp_depth_batchtospace)
  92. self.aspp_depths = nn.CellList(self.aspp_depths)
  93. self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs)
  94. self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces)
  95. self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1])))
  96. self.global_poolings = []
  97. for scale_size in scale_sizes:
  98. pooling_h = np.ceil((feature_shape[0]*scale_size)/output_stride)
  99. pooling_w = np.ceil((feature_shape[0]*scale_size)/output_stride)
  100. self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w))))
  101. self.global_poolings = nn.CellList(self.global_poolings)
  102. self.conv_bn = _conv_bn_relu(channel,
  103. depth,
  104. ksize=1,
  105. stride=1,
  106. use_batch_statistics=fine_tune_batch_norm)
  107. self.samples = []
  108. for scale_size in scale_sizes:
  109. self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride))
  110. self.samples = nn.CellList(self.samples)
  111. self.feature_shape = feature_shape
  112. self.concat = P.Concat(axis=1)
  113. @add_flags(loop_can_unroll=True)
  114. def construct(self, x, scale_index=0):
  115. aspp0 = self.aspp0(x)
  116. aspp1 = self.global_poolings[scale_index](x)
  117. aspp1 = self.conv_bn(aspp1)
  118. aspp1 = self.samples[scale_index](aspp1)
  119. output = self.concat((aspp1, aspp0))
  120. for i in range(len(self.atrous_rates)):
  121. aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x)
  122. aspp_i = self.aspp_depth_depthwiseconv(aspp_i)
  123. aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i)
  124. aspp_i = self.aspp_depth_bn(aspp_i)
  125. aspp_i = self.aspp_depth_relu(aspp_i)
  126. aspp_i = self.aspp_pointwise(aspp_i)
  127. output = self.concat((output, aspp_i))
  128. return output
  129. class DecoderSampleBlock(nn.Cell):
  130. """Decoder sample block."""
  131. def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4):
  132. super(DecoderSampleBlock, self).__init__()
  133. sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1
  134. sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1
  135. self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)
  136. def construct(self, x):
  137. return self.sample(x)
  138. class Decoder(nn.Cell):
  139. """
  140. Decode module for DeepLabv3.
  141. Args:
  142. low_level_channel (int): Low level input channel
  143. channel (int): Input channel.
  144. depth (int): Output channel.
  145. feature_shape (list): 'Input image shape, [N,C,H,W].'
  146. scale_sizes (list): 'Input scales for multi-scale feature extraction.'
  147. decoder_output_stride (int): 'The ratio of input to output spatial resolution'
  148. fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
  149. Returns:
  150. Tensor, output tensor.
  151. Examples:
  152. >>> Decoder(256, 100, [56,56])
  153. """
  154. def __init__(self,
  155. low_level_channel,
  156. channel,
  157. depth,
  158. feature_shape,
  159. scale_sizes,
  160. decoder_output_stride,
  161. fine_tune_batch_norm):
  162. super(Decoder, self).__init__()
  163. self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1,
  164. pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
  165. self.decoder_depth0 = _deep_conv_bn_relu(channel + 48,
  166. channel_multiplier=1,
  167. ksize=3,
  168. stride=1,
  169. pad_mode="same",
  170. dilation=1,
  171. use_batch_statistics=fine_tune_batch_norm)
  172. self.decoder_pointwise0 = _conv_bn_relu(channel + 48,
  173. depth,
  174. ksize=1,
  175. stride=1,
  176. use_batch_statistics=fine_tune_batch_norm)
  177. self.decoder_depth1 = _deep_conv_bn_relu(depth,
  178. channel_multiplier=1,
  179. ksize=3,
  180. stride=1,
  181. pad_mode="same",
  182. dilation=1,
  183. use_batch_statistics=fine_tune_batch_norm)
  184. self.decoder_pointwise1 = _conv_bn_relu(depth,
  185. depth,
  186. ksize=1,
  187. stride=1,
  188. use_batch_statistics=fine_tune_batch_norm)
  189. self.depth = depth
  190. self.concat = P.Concat(axis=1)
  191. self.samples = []
  192. for scale_size in scale_sizes:
  193. self.samples.append(DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride))
  194. self.samples = nn.CellList(self.samples)
  195. def construct(self, x, low_level_feature, scale_index):
  196. low_level_feature = self.feature_projection(low_level_feature)
  197. low_level_feature = self.samples[scale_index](low_level_feature)
  198. x = self.samples[scale_index](x)
  199. output = self.concat((x, low_level_feature))
  200. output = self.decoder_depth0(output)
  201. output = self.decoder_pointwise0(output)
  202. output = self.decoder_depth1(output)
  203. output = self.decoder_pointwise1(output)
  204. return output
  205. class SingleDeepLabV3(nn.Cell):
  206. """
  207. DeepLabv3 Network.
  208. Args:
  209. num_classes (int): Class number.
  210. feature_shape (list): Input image shape, [N,C,H,W].
  211. backbone (Cell): Backbone Network.
  212. channel (int): Resnet output channel.
  213. depth (int): ASPP block depth.
  214. scale_sizes (list): Input scales for multi-scale feature extraction.
  215. atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
  216. decoder_output_stride (int): 'The ratio of input to output spatial resolution'
  217. output_stride (int): 'The ratio of input to output spatial resolution.'
  218. fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
  219. Returns:
  220. Tensor, output tensor.
  221. Examples:
  222. >>> SingleDeepLabV3(num_classes=10,
  223. >>> feature_shape=[1,3,224,224],
  224. >>> backbone=resnet50_dl(),
  225. >>> channel=2048,
  226. >>> depth=256)
  227. >>> scale_sizes=[1.0])
  228. >>> atrous_rates=[6])
  229. >>> decoder_output_stride=4)
  230. >>> output_stride=16)
  231. """
  232. def __init__(self,
  233. num_classes,
  234. feature_shape,
  235. backbone,
  236. channel,
  237. depth,
  238. scale_sizes,
  239. atrous_rates,
  240. decoder_output_stride,
  241. output_stride,
  242. fine_tune_batch_norm=False):
  243. super(SingleDeepLabV3, self).__init__()
  244. self.num_classes = num_classes
  245. self.channel = channel
  246. self.depth = depth
  247. self.scale_sizes = []
  248. for scale_size in np.sort(scale_sizes):
  249. self.scale_sizes.append(scale_size)
  250. self.net = backbone
  251. self.aspp = ASPP(channel=self.channel,
  252. depth=self.depth,
  253. feature_shape=[feature_shape[2],
  254. feature_shape[3]],
  255. scale_sizes=self.scale_sizes,
  256. atrous_rates=atrous_rates,
  257. output_stride=output_stride,
  258. fine_tune_batch_norm=fine_tune_batch_norm)
  259. atrous_rates_len = 0
  260. if atrous_rates is not None:
  261. atrous_rates_len = len(atrous_rates)
  262. self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth,
  263. ksize=1,
  264. stride=1,
  265. use_batch_statistics=fine_tune_batch_norm)
  266. self.fc2 = nn.Conv2d(depth,
  267. num_classes,
  268. kernel_size=1,
  269. stride=1,
  270. has_bias=True)
  271. self.upsample = P.ResizeBilinear((int(feature_shape[2]),
  272. int(feature_shape[3])),
  273. align_corners=True)
  274. self.samples = []
  275. for scale_size in self.scale_sizes:
  276. self.samples.append(SampleBlock(feature_shape, scale_size))
  277. self.samples = nn.CellList(self.samples)
  278. self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]),
  279. float(feature_shape[3])]
  280. self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
  281. self.dropout = nn.Dropout(keep_prob=0.9)
  282. self.shape = P.Shape()
  283. self.decoder_output_stride = decoder_output_stride
  284. if decoder_output_stride is not None:
  285. self.decoder = Decoder(low_level_channel=depth,
  286. channel=depth,
  287. depth=depth,
  288. feature_shape=[feature_shape[2],
  289. feature_shape[3]],
  290. scale_sizes=self.scale_sizes,
  291. decoder_output_stride=decoder_output_stride,
  292. fine_tune_batch_norm=fine_tune_batch_norm)
  293. def construct(self, x, scale_index=0):
  294. x = (2.0 / 255.0) * x - 1.0
  295. x = self.pad(x)
  296. low_level_feature, feature_map = self.net(x)
  297. for scale_size in self.scale_sizes:
  298. if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2:
  299. output = self.aspp(feature_map, scale_index)
  300. output = self.fc1(output)
  301. if self.decoder_output_stride is not None:
  302. output = self.decoder(output, low_level_feature, scale_index)
  303. output = self.fc2(output)
  304. output = self.samples[scale_index](output)
  305. return output
  306. scale_index += 1
  307. return feature_map
  308. class SampleBlock(nn.Cell):
  309. """Sample block."""
  310. def __init__(self,
  311. feature_shape,
  312. scale_size=1.0):
  313. super(SampleBlock, self).__init__()
  314. sample_h = np.ceil(float(feature_shape[2]) * scale_size)
  315. sample_w = np.ceil(float(feature_shape[3]) * scale_size)
  316. self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)
  317. def construct(self, x):
  318. return self.sample(x)
  319. class DeepLabV3(nn.Cell):
  320. """DeepLabV3 model."""
  321. def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates,
  322. decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid):
  323. super(DeepLabV3, self).__init__()
  324. self.infer_scale_sizes = []
  325. if infer_scale_sizes is not None:
  326. self.infer_scale_sizes = infer_scale_sizes
  327. self.infer_scale_sizes = infer_scale_sizes
  328. if image_pyramid is None:
  329. image_pyramid = [1.0]
  330. self.image_pyramid = image_pyramid
  331. scale_sizes = []
  332. for pyramid in image_pyramid:
  333. scale_sizes.append(pyramid)
  334. for scale in infer_scale_sizes:
  335. scale_sizes.append(scale)
  336. self.samples = []
  337. for scale_size in scale_sizes:
  338. self.samples.append(SampleBlock(feature_shape, scale_size))
  339. self.samples = nn.CellList(self.samples)
  340. self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes,
  341. feature_shape=feature_shape,
  342. backbone=resnet50_dl(fine_tune_batch_norm),
  343. channel=channel,
  344. depth=depth,
  345. scale_sizes=scale_sizes,
  346. atrous_rates=atrous_rates,
  347. decoder_output_stride=decoder_output_stride,
  348. output_stride=output_stride,
  349. fine_tune_batch_norm=fine_tune_batch_norm)
  350. self.softmax = P.Softmax(axis=1)
  351. self.concat = P.Concat(axis=2)
  352. self.expand_dims = P.ExpandDims()
  353. self.reduce_mean = P.ReduceMean()
  354. self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
  355. int(feature_shape[3])),
  356. align_corners=True)
  357. def construct(self, x):
  358. logits = ()
  359. if self.training:
  360. if len(self.image_pyramid) >= 1:
  361. if self.image_pyramid[0] == 1:
  362. logits = self.deeplabv3(x)
  363. else:
  364. x1 = self.samples[0](x)
  365. logits = self.deeplabv3(x1)
  366. logits = self.sample_common(logits)
  367. logits = self.expand_dims(logits, 2)
  368. for i in range(len(self.image_pyramid) - 1):
  369. x_i = self.samples[i + 1](x)
  370. logits_i = self.deeplabv3(x_i)
  371. logits_i = self.sample_common(logits_i)
  372. logits_i = self.expand_dims(logits_i, 2)
  373. logits = self.concat((logits, logits_i))
  374. logits = self.reduce_mean(logits, 2)
  375. return logits
  376. if len(self.infer_scale_sizes) >= 1:
  377. infer_index = len(self.image_pyramid)
  378. x1 = self.samples[infer_index](x)
  379. logits = self.deeplabv3(x1)
  380. logits = self.sample_common(logits)
  381. logits = self.softmax(logits)
  382. logits = self.expand_dims(logits, 2)
  383. for i in range(len(self.infer_scale_sizes) - 1):
  384. x_i = self.samples[i + 1 + infer_index](x)
  385. logits_i = self.deeplabv3(x_i)
  386. logits_i = self.sample_common(logits_i)
  387. logits_i = self.softmax(logits_i)
  388. logits_i = self.expand_dims(logits_i, 2)
  389. logits = self.concat((logits, logits_i))
  390. logits = self.reduce_mean(logits, 2)
  391. return logits
  392. def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid,
  393. infer_scale_sizes, atrous_rates=None, decoder_output_stride=None,
  394. output_stride=16, fine_tune_batch_norm=False):
  395. """
  396. ResNet50 based DeepLabv3 network.
  397. Args:
  398. num_classes (int): Class number.
  399. feature_shape (list): Input image shape, [N,C,H,W].
  400. image_pyramid (list): Input scales for multi-scale feature extraction.
  401. atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
  402. infer_scale_sizes (list): 'The scales to resize images for inference.
  403. decoder_output_stride (int): 'The ratio of input to output spatial resolution'
  404. output_stride (int): 'The ratio of input to output spatial resolution.'
  405. fine_tune_batch_norm (bool): 'Fine tune the batch norm parameters or not'
  406. Returns:
  407. Cell, cell instance of ResNet50 based DeepLabv3 neural network.
  408. Examples:
  409. >>> deeplabv3_resnet50(100, [1,3,224,224],[1.0],[1.0])
  410. """
  411. return DeepLabV3(num_classes=num_classes,
  412. feature_shape=feature_shape,
  413. backbone=resnet50_dl(fine_tune_batch_norm),
  414. channel=2048,
  415. depth=256,
  416. infer_scale_sizes=infer_scale_sizes,
  417. atrous_rates=atrous_rates,
  418. decoder_output_stride=decoder_output_stride,
  419. output_stride=output_stride,
  420. fine_tune_batch_norm=fine_tune_batch_norm,
  421. image_pyramid=image_pyramid)