You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

retinanet.py 19 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """retinanet based resnet."""
  16. import mindspore.common.dtype as mstype
  17. import mindspore as ms
  18. import mindspore.nn as nn
  19. from mindspore import context, Tensor
  20. from mindspore.context import ParallelMode
  21. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  22. from mindspore.communication.management import get_group_size
  23. from mindspore.ops import operations as P
  24. from mindspore.ops import functional as F
  25. from mindspore.ops import composite as C
  26. def _bn(channel):
  27. return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.97,
  28. gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
  29. class ConvBNReLU(nn.Cell):
  30. """
  31. Convolution/Depthwise fused with Batchnorm and ReLU block definition.
  32. Args:
  33. in_planes (int): Input channel.
  34. out_planes (int): Output channel.
  35. kernel_size (int): Input kernel size.
  36. stride (int): Stride size for the first convolutional layer. Default: 1.
  37. groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.
  38. Returns:
  39. Tensor, output tensor.
  40. Examples:
  41. >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
  42. """
  43. def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
  44. super(ConvBNReLU, self).__init__()
  45. padding = 0
  46. conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
  47. padding=padding)
  48. layers = [conv, _bn(out_planes), nn.ReLU()]
  49. self.features = nn.SequentialCell(layers)
  50. def construct(self, x):
  51. output = self.features(x)
  52. return output
  53. class ResidualBlock(nn.Cell):
  54. """
  55. ResNet V1 residual block definition.
  56. Args:
  57. in_channel (int): Input channel.
  58. out_channel (int): Output channel.
  59. stride (int): Stride size for the first convolutional layer. Default: 1.
  60. Returns:
  61. Tensor, output tensor.
  62. Examples:
  63. >>> ResidualBlock(3, 256, stride=2)
  64. """
  65. expansion = 4
  66. def __init__(self,
  67. in_channel,
  68. out_channel,
  69. stride=1):
  70. super(ResidualBlock, self).__init__()
  71. channel = out_channel // self.expansion
  72. self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
  73. self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
  74. self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same', padding=0,
  75. has_bn=True, activation='relu')
  76. self.down_sample = False
  77. if stride != 1 or in_channel != out_channel:
  78. self.down_sample = True
  79. self.down_sample_layer = None
  80. if self.down_sample:
  81. self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel,
  82. kernel_size=1, stride=stride,
  83. pad_mode='same', padding=0, has_bn=True, activation='relu')
  84. self.add = P.TensorAdd()
  85. self.relu = P.ReLU()
  86. def construct(self, x):
  87. identity = x
  88. out = self.conv1(x)
  89. out = self.conv2(out)
  90. out = self.conv3(out)
  91. if self.down_sample:
  92. identity = self.down_sample_layer(identity)
  93. out = self.add(out, identity)
  94. out = self.relu(out)
  95. return out
  96. class FlattenConcat(nn.Cell):
  97. """
  98. Concatenate predictions into a single tensor.
  99. Args:
  100. config (dict): The default config of retinanet.
  101. Returns:
  102. Tensor, flatten predictions.
  103. """
  104. def __init__(self, config):
  105. super(FlattenConcat, self).__init__()
  106. self.num_retinanet_boxes = config.num_retinanet_boxes
  107. self.concat = P.Concat(axis=1)
  108. self.transpose = P.Transpose()
  109. def construct(self, inputs):
  110. output = ()
  111. batch_size = F.shape(inputs[0])[0]
  112. for x in inputs:
  113. x = self.transpose(x, (0, 2, 3, 1))
  114. output += (F.reshape(x, (batch_size, -1)),)
  115. res = self.concat(output)
  116. return F.reshape(res, (batch_size, self.num_retinanet_boxes, -1))
  117. def ClassificationModel(in_channel, num_anchors, kernel_size=3,
  118. stride=1, pad_mod='same', num_classes=81, feature_size=256):
  119. conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=3, pad_mode='same')
  120. conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same')
  121. conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same')
  122. conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same')
  123. conv5 = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=3, pad_mode='same')
  124. return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5])
  125. def RegressionModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mod='same', feature_size=256):
  126. conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=3, pad_mode='same')
  127. conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same')
  128. conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same')
  129. conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, pad_mode='same')
  130. conv5 = nn.Conv2d(feature_size, num_anchors * 4, kernel_size=3, pad_mode='same')
  131. return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5])
  132. class MultiBox(nn.Cell):
  133. """
  134. Multibox conv layers. Each multibox layer contains class conf scores and localization predictions.
  135. Args:
  136. config (dict): The default config of retinanet.
  137. Returns:
  138. Tensor, localization predictions.
  139. Tensor, class conf scores.
  140. """
  141. def __init__(self, config):
  142. super(MultiBox, self).__init__()
  143. out_channels = config.extras_out_channels
  144. num_default = config.num_default
  145. loc_layers = []
  146. cls_layers = []
  147. for k, out_channel in enumerate(out_channels):
  148. loc_layers += [RegressionModel(in_channel=out_channel, num_anchors=num_default[k])]
  149. cls_layers += [ClassificationModel(in_channel=out_channel, num_anchors=num_default[k])]
  150. self.multi_loc_layers = nn.layer.CellList(loc_layers)
  151. self.multi_cls_layers = nn.layer.CellList(cls_layers)
  152. self.flatten_concat = FlattenConcat(config)
  153. def construct(self, inputs):
  154. loc_outputs = ()
  155. cls_outputs = ()
  156. for i in range(len(self.multi_loc_layers)):
  157. loc_outputs += (self.multi_loc_layers[i](inputs[i]),)
  158. cls_outputs += (self.multi_cls_layers[i](inputs[i]),)
  159. return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)
  160. class SigmoidFocalClassificationLoss(nn.Cell):
  161. """"
  162. Sigmoid focal-loss for classification.
  163. Args:
  164. gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0
  165. alpha (float): Hyper-parameter to balance the positive and negative example. Default: 0.25
  166. Returns:
  167. Tensor, the focal loss.
  168. """
  169. def __init__(self, gamma=2.0, alpha=0.25):
  170. super(SigmoidFocalClassificationLoss, self).__init__()
  171. self.sigmiod_cross_entropy = P.SigmoidCrossEntropyWithLogits()
  172. self.sigmoid = P.Sigmoid()
  173. self.pow = P.Pow()
  174. self.onehot = P.OneHot()
  175. self.on_value = Tensor(1.0, mstype.float32)
  176. self.off_value = Tensor(0.0, mstype.float32)
  177. self.gamma = gamma
  178. self.alpha = alpha
  179. def construct(self, logits, label):
  180. label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
  181. sigmiod_cross_entropy = self.sigmiod_cross_entropy(logits, label)
  182. sigmoid = self.sigmoid(logits)
  183. label = F.cast(label, mstype.float32)
  184. p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
  185. modulating_factor = self.pow(1 - p_t, self.gamma)
  186. alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
  187. focal_loss = modulating_factor * alpha_weight_factor * sigmiod_cross_entropy
  188. return focal_loss
  189. class retinanetWithLossCell(nn.Cell):
  190. """"
  191. Provide retinanet training loss through network.
  192. Args:
  193. network (Cell): The training network.
  194. config (dict): retinanet config.
  195. Returns:
  196. Tensor, the loss of the network.
  197. """
  198. def __init__(self, network, config):
  199. super(retinanetWithLossCell, self).__init__()
  200. self.network = network
  201. self.less = P.Less()
  202. self.tile = P.Tile()
  203. self.reduce_sum = P.ReduceSum()
  204. self.reduce_mean = P.ReduceMean()
  205. self.expand_dims = P.ExpandDims()
  206. self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha)
  207. self.loc_loss = nn.SmoothL1Loss()
  208. def construct(self, x, gt_loc, gt_label, num_matched_boxes):
  209. pred_loc, pred_label = self.network(x)
  210. mask = F.cast(self.less(0, gt_label), mstype.float32)
  211. num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))
  212. # Localization Loss
  213. mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4))
  214. smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc
  215. loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1)
  216. # Classification Loss
  217. loss_cls = self.class_loss(pred_label, gt_label)
  218. loss_cls = self.reduce_sum(loss_cls, (1, 2))
  219. return self.reduce_sum((loss_cls + loss_loc) /num_matched_boxes)
  220. class TrainingWrapper(nn.Cell):
  221. """
  222. Encapsulation class of retinanet network training.
  223. Append an optimizer to the training network after that the construct
  224. function can be called to create the backward graph.
  225. Args:
  226. network (Cell): The training network. Note that loss function should have been added.
  227. optimizer (Optimizer): Optimizer for updating the weights.
  228. sens (Number): The adjust parameter. Default: 1.0.
  229. """
  230. def __init__(self, network, optimizer, sens=1.0):
  231. super(TrainingWrapper, self).__init__(auto_prefix=False)
  232. self.network = network
  233. self.network.set_grad()
  234. self.weights = ms.ParameterTuple(network.trainable_params())
  235. self.optimizer = optimizer
  236. self.grad = C.GradOperation(get_by_list=True, sens_param=True)
  237. self.sens = sens
  238. self.reducer_flag = False
  239. self.grad_reducer = None
  240. self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
  241. if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
  242. self.reducer_flag = True
  243. if self.reducer_flag:
  244. mean = context.get_auto_parallel_context("gradients_mean")
  245. if auto_parallel_context().get_device_num_is_set():
  246. degree = context.get_auto_parallel_context("device_num")
  247. else:
  248. degree = get_group_size()
  249. self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
  250. def construct(self, *args):
  251. weights = self.weights
  252. loss = self.network(*args)
  253. sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
  254. grads = self.grad(self.network, weights)(*args, sens)
  255. if self.reducer_flag:
  256. # apply grad reducer on grads
  257. grads = self.grad_reducer(grads)
  258. return F.depend(loss, self.optimizer(grads))
  259. class resnet(nn.Cell):
  260. """
  261. ResNet architecture.
  262. Args:
  263. block (Cell): Block for network.
  264. layer_nums (list): Numbers of block in different layers.
  265. in_channels (list): Input channel in each layer.
  266. out_channels (list): Output channel in each layer.
  267. strides (list): Stride size in each layer.
  268. num_classes (int): The number of classes that the training images are belonging to.
  269. Returns:
  270. Tensor, output tensor.
  271. Examples:
  272. >>> ResNet(ResidualBlock,
  273. >>> [3, 4, 6, 3],
  274. >>> [64, 256, 512, 1024],
  275. >>> [256, 512, 1024, 2048],
  276. >>> [1, 2, 2, 2],
  277. >>> 10)
  278. """
  279. def __init__(self,
  280. block,
  281. layer_nums,
  282. in_channels,
  283. out_channels,
  284. strides,
  285. num_classes):
  286. super(resnet, self).__init__()
  287. if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
  288. raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
  289. self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
  290. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
  291. self.layer1 = self._make_layer(block,
  292. layer_nums[0],
  293. in_channel=in_channels[0],
  294. out_channel=out_channels[0],
  295. stride=strides[0])
  296. self.layer2 = self._make_layer(block,
  297. layer_nums[1],
  298. in_channel=in_channels[1],
  299. out_channel=out_channels[1],
  300. stride=strides[1])
  301. self.layer3 = self._make_layer(block,
  302. layer_nums[2],
  303. in_channel=in_channels[2],
  304. out_channel=out_channels[2],
  305. stride=strides[2])
  306. self.layer4 = self._make_layer(block,
  307. layer_nums[3],
  308. in_channel=in_channels[3],
  309. out_channel=out_channels[3],
  310. stride=strides[3])
  311. def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
  312. """
  313. Make stage network of ResNet.
  314. Args:
  315. block (Cell): Resnet block.
  316. layer_num (int): Layer number.
  317. in_channel (int): Input channel.
  318. out_channel (int): Output channel.
  319. stride (int): Stride size for the first convolutional layer.
  320. Returns:
  321. SequentialCell, the output layer.
  322. Examples:
  323. >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
  324. """
  325. layers = []
  326. resnet_block = ResidualBlock(in_channel, out_channel, stride=stride)
  327. layers.append(resnet_block)
  328. for _ in range(1, layer_num):
  329. resnet_block = ResidualBlock(out_channel, out_channel, stride=1)
  330. layers.append(resnet_block)
  331. return nn.SequentialCell(layers)
  332. def construct(self, x):
  333. x = self.conv1(x)
  334. C1 = self.maxpool(x)
  335. C2 = self.layer1(C1)
  336. C3 = self.layer2(C2)
  337. C4 = self.layer3(C3)
  338. C5 = self.layer4(C4)
  339. return C3, C4, C5
  340. def resnet50(num_classes):
  341. """
  342. Get ResNet50 neural network.
  343. Args:
  344. class_num (int): Class number.
  345. Returns:
  346. Cell, cell instance of ResNet50 neural network.
  347. Examples:
  348. >>> net = resnet50_quant(10)
  349. """
  350. return resnet(ResidualBlock,
  351. [3, 4, 6, 3],
  352. [64, 256, 512, 1024],
  353. [256, 512, 1024, 2048],
  354. [1, 2, 2, 2],
  355. num_classes)
  356. class retinanet50(nn.Cell):
  357. def __init__(self, backbone, config, is_training=True):
  358. super(retinanet50, self).__init__()
  359. self.backbone = backbone
  360. feature_size = config.feature_size
  361. self.P5_1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, pad_mode='same')
  362. self.P_upsample1 = P.ResizeNearestNeighbor((feature_size[1], feature_size[1]))
  363. self.P5_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')
  364. self.P4_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, pad_mode='same')
  365. self.P_upsample2 = P.ResizeNearestNeighbor((feature_size[0], feature_size[0]))
  366. self.P4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')
  367. self.P3_1 = nn.Conv2d(512, 256, kernel_size=1, stride=1, pad_mode='same')
  368. self.P3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, pad_mode='same')
  369. self.P6_0 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, pad_mode='same')
  370. self.P7_1 = nn.ReLU()
  371. self.P7_2 = nn.Conv2d(256, 256, kernel_size=3, stride=2, pad_mode='same')
  372. self.multi_box = MultiBox(config)
  373. self.is_training = is_training
  374. if not is_training:
  375. self.activation = P.Sigmoid()
  376. def construct(self, x):
  377. C3, C4, C5 = self.backbone(x)
  378. P5 = self.P5_1(C5)
  379. P5_upsampled = self.P_upsample1(P5)
  380. P5 = self.P5_2(P5)
  381. P4 = self.P4_1(C4)
  382. P4 = P5_upsampled +P4
  383. P4_upsampled = self.P_upsample2(P4)
  384. P4 = self.P4_2(P4)
  385. P3 = self.P3_1(C3)
  386. P3 = P4_upsampled + P3
  387. P3 = self.P3_2(P3)
  388. P6 = self.P6_0(C5)
  389. P7 = self.P7_1(P6)
  390. P7 = self.P7_2(P7)
  391. multi_feature = (P3, P4, P5, P6, P7)
  392. pred_loc, pred_label = self.multi_box(multi_feature)
  393. return pred_loc, pred_label
  394. class retinanetInferWithDecoder(nn.Cell):
  395. """
  396. retinanet Infer wrapper to decode the bbox locations.
  397. Args:
  398. network (Cell): the origin retinanet infer network without bbox decoder.
  399. default_boxes (Tensor): the default_boxes from anchor generator
  400. config (dict): retinanet config
  401. Returns:
  402. Tensor, the locations for bbox after decoder representing (y0,x0,y1,x1)
  403. Tensor, the prediction labels.
  404. """
  405. def __init__(self, network, default_boxes, config):
  406. super(retinanetInferWithDecoder, self).__init__()
  407. self.network = network
  408. self.default_boxes = default_boxes
  409. self.prior_scaling_xy = config.prior_scaling[0]
  410. self.prior_scaling_wh = config.prior_scaling[1]
  411. def construct(self, x):
  412. pred_loc, pred_label = self.network(x)
  413. default_bbox_xy = self.default_boxes[..., :2]
  414. default_bbox_wh = self.default_boxes[..., 2:]
  415. pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy
  416. pred_wh = P.Exp()(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh
  417. pred_xy_0 = pred_xy - pred_wh / 2.0
  418. pred_xy_1 = pred_xy + pred_wh / 2.0
  419. pred_xy = P.Concat(-1)((pred_xy_0, pred_xy_1))
  420. pred_xy = P.Maximum()(pred_xy, 0)
  421. pred_xy = P.Minimum()(pred_xy, 1)
  422. return pred_xy, pred_label