You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

yolov3.py 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YOLOv3 based on ResNet18."""
  16. import numpy as np
  17. import mindspore as ms
  18. import mindspore.nn as nn
  19. from mindspore import context, Tensor
  20. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  21. from mindspore.communication.management import get_group_size
  22. from mindspore.common.initializer import TruncatedNormal
  23. from mindspore.ops import operations as P
  24. from mindspore.ops import functional as F
  25. from mindspore.ops import composite as C
  26. def weight_variable():
  27. """Weight variable."""
  28. return TruncatedNormal(0.02)
  29. class _conv2d(nn.Cell):
  30. """Create Conv2D with padding."""
  31. def __init__(self, in_channels, out_channels, kernel_size, stride=1):
  32. super(_conv2d, self).__init__()
  33. self.conv = nn.Conv2d(in_channels, out_channels,
  34. kernel_size=kernel_size, stride=stride, padding=0, pad_mode='same',
  35. weight_init=weight_variable())
  36. def construct(self, x):
  37. x = self.conv(x)
  38. return x
  39. def _fused_bn(channels, momentum=0.99):
  40. """Get a fused batchnorm."""
  41. return nn.BatchNorm2d(channels, momentum=momentum)
  42. def _conv_bn_relu(in_channel,
  43. out_channel,
  44. ksize,
  45. stride=1,
  46. padding=0,
  47. dilation=1,
  48. alpha=0.1,
  49. momentum=0.99,
  50. pad_mode="same"):
  51. """Get a conv2d batchnorm and relu layer."""
  52. return nn.SequentialCell(
  53. [nn.Conv2d(in_channel,
  54. out_channel,
  55. kernel_size=ksize,
  56. stride=stride,
  57. padding=padding,
  58. dilation=dilation,
  59. pad_mode=pad_mode),
  60. nn.BatchNorm2d(out_channel, momentum=momentum),
  61. nn.LeakyReLU(alpha)]
  62. )
  63. class BasicBlock(nn.Cell):
  64. """
  65. ResNet basic block.
  66. Args:
  67. in_channels (int): Input channel.
  68. out_channels (int): Output channel.
  69. stride (int): Stride size for the initial convolutional layer. Default:1.
  70. momentum (float): Momentum for batchnorm layer. Default:0.1.
  71. Returns:
  72. Tensor, output tensor.
  73. Examples:
  74. BasicBlock(3,256,stride=2,down_sample=True).
  75. """
  76. expansion = 1
  77. def __init__(self,
  78. in_channels,
  79. out_channels,
  80. stride=1,
  81. momentum=0.99):
  82. super(BasicBlock, self).__init__()
  83. self.conv1 = _conv2d(in_channels, out_channels, 3, stride=stride)
  84. self.bn1 = _fused_bn(out_channels, momentum=momentum)
  85. self.conv2 = _conv2d(out_channels, out_channels, 3)
  86. self.bn2 = _fused_bn(out_channels, momentum=momentum)
  87. self.relu = P.ReLU()
  88. self.down_sample_layer = None
  89. self.downsample = (in_channels != out_channels)
  90. if self.downsample:
  91. self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
  92. self.add = P.TensorAdd()
  93. def construct(self, x):
  94. identity = x
  95. x = self.conv1(x)
  96. x = self.bn1(x)
  97. x = self.relu(x)
  98. x = self.conv2(x)
  99. x = self.bn2(x)
  100. if self.downsample:
  101. identity = self.down_sample_layer(identity)
  102. out = self.add(x, identity)
  103. out = self.relu(out)
  104. return out
  105. class ResNet(nn.Cell):
  106. """
  107. ResNet network.
  108. Args:
  109. block (Cell): Block for network.
  110. layer_nums (list): Numbers of different layers.
  111. in_channels (int): Input channel.
  112. out_channels (int): Output channel.
  113. num_classes (int): Class number. Default:100.
  114. Returns:
  115. Tensor, output tensor.
  116. Examples:
  117. ResNet(ResidualBlock,
  118. [3, 4, 6, 3],
  119. [64, 256, 512, 1024],
  120. [256, 512, 1024, 2048],
  121. 100).
  122. """
  123. def __init__(self,
  124. block,
  125. layer_nums,
  126. in_channels,
  127. out_channels,
  128. strides=None,
  129. num_classes=80):
  130. super(ResNet, self).__init__()
  131. if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
  132. raise ValueError("the length of "
  133. "layer_num, inchannel, outchannel list must be 4!")
  134. self.conv1 = _conv2d(3, 64, 7, stride=2)
  135. self.bn1 = _fused_bn(64)
  136. self.relu = P.ReLU()
  137. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
  138. self.layer1 = self._make_layer(block,
  139. layer_nums[0],
  140. in_channel=in_channels[0],
  141. out_channel=out_channels[0],
  142. stride=strides[0])
  143. self.layer2 = self._make_layer(block,
  144. layer_nums[1],
  145. in_channel=in_channels[1],
  146. out_channel=out_channels[1],
  147. stride=strides[1])
  148. self.layer3 = self._make_layer(block,
  149. layer_nums[2],
  150. in_channel=in_channels[2],
  151. out_channel=out_channels[2],
  152. stride=strides[2])
  153. self.layer4 = self._make_layer(block,
  154. layer_nums[3],
  155. in_channel=in_channels[3],
  156. out_channel=out_channels[3],
  157. stride=strides[3])
  158. self.num_classes = num_classes
  159. if num_classes:
  160. self.reduce_mean = P.ReduceMean(keep_dims=True)
  161. self.end_point = nn.Dense(out_channels[3], num_classes, has_bias=True,
  162. weight_init=weight_variable(),
  163. bias_init=weight_variable())
  164. self.squeeze = P.Squeeze(axis=(2, 3))
  165. def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
  166. """
  167. Make Layer for ResNet.
  168. Args:
  169. block (Cell): Resnet block.
  170. layer_num (int): Layer number.
  171. in_channel (int): Input channel.
  172. out_channel (int): Output channel.
  173. stride (int): Stride size for the initial convolutional layer.
  174. Returns:
  175. SequentialCell, the output layer.
  176. Examples:
  177. _make_layer(BasicBlock, 3, 128, 256, 2).
  178. """
  179. layers = []
  180. resblk = block(in_channel, out_channel, stride=stride)
  181. layers.append(resblk)
  182. for _ in range(1, layer_num - 1):
  183. resblk = block(out_channel, out_channel, stride=1)
  184. layers.append(resblk)
  185. resblk = block(out_channel, out_channel, stride=1)
  186. layers.append(resblk)
  187. return nn.SequentialCell(layers)
  188. def construct(self, x):
  189. x = self.conv1(x)
  190. x = self.bn1(x)
  191. x = self.relu(x)
  192. c1 = self.maxpool(x)
  193. c2 = self.layer1(c1)
  194. c3 = self.layer2(c2)
  195. c4 = self.layer3(c3)
  196. c5 = self.layer4(c4)
  197. out = c5
  198. if self.num_classes:
  199. out = self.reduce_mean(c5, (2, 3))
  200. out = self.squeeze(out)
  201. out = self.end_point(out)
  202. return c3, c4, out
  203. def resnet18(class_num=10):
  204. """
  205. Get ResNet18 neural network.
  206. Args:
  207. class_num (int): Class number.
  208. Returns:
  209. Cell, cell instance of ResNet18 neural network.
  210. Examples:
  211. resnet18(100).
  212. """
  213. return ResNet(BasicBlock,
  214. [2, 2, 2, 2],
  215. [64, 64, 128, 256],
  216. [64, 128, 256, 512],
  217. [1, 2, 2, 2],
  218. num_classes=class_num)
  219. class YoloBlock(nn.Cell):
  220. """
  221. YoloBlock for YOLOv3.
  222. Args:
  223. in_channels (int): Input channel.
  224. out_chls (int): Middle channel.
  225. out_channels (int): Output channel.
  226. Returns:
  227. Tuple, tuple of output tensor,(f1,f2,f3).
  228. Examples:
  229. YoloBlock(1024, 512, 255).
  230. """
  231. def __init__(self, in_channels, out_chls, out_channels):
  232. super(YoloBlock, self).__init__()
  233. out_chls_2 = out_chls * 2
  234. self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1)
  235. self.conv1 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
  236. self.conv2 = _conv_bn_relu(out_chls_2, out_chls, ksize=1)
  237. self.conv3 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
  238. self.conv4 = _conv_bn_relu(out_chls_2, out_chls, ksize=1)
  239. self.conv5 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
  240. self.conv6 = nn.Conv2d(out_chls_2, out_channels, kernel_size=1, stride=1, has_bias=True)
  241. def construct(self, x):
  242. c1 = self.conv0(x)
  243. c2 = self.conv1(c1)
  244. c3 = self.conv2(c2)
  245. c4 = self.conv3(c3)
  246. c5 = self.conv4(c4)
  247. c6 = self.conv5(c5)
  248. out = self.conv6(c6)
  249. return c5, out
  250. class YOLOv3(nn.Cell):
  251. """
  252. YOLOv3 Network.
  253. Note:
  254. backbone = resnet18.
  255. Args:
  256. feature_shape (list): Input image shape, [N,C,H,W].
  257. backbone_shape (list): resnet18 output channels shape.
  258. backbone (Cell): Backbone Network.
  259. out_channel (int): Output channel.
  260. Returns:
  261. Tensor, output tensor.
  262. Examples:
  263. YOLOv3(feature_shape=[1,3,416,416],
  264. backbone_shape=[64, 128, 256, 512, 1024]
  265. backbone=darknet53(),
  266. out_channel=255).
  267. """
  268. def __init__(self, feature_shape, backbone_shape, backbone, out_channel):
  269. super(YOLOv3, self).__init__()
  270. self.out_channel = out_channel
  271. self.net = backbone
  272. self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel)
  273. self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1)
  274. self.upsample1 = P.ResizeNearestNeighbor((feature_shape[2]//16, feature_shape[3]//16))
  275. self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3],
  276. out_chls=backbone_shape[-3],
  277. out_channels=out_channel)
  278. self.conv2 = _conv_bn_relu(in_channel=backbone_shape[-3], out_channel=backbone_shape[-3]//2, ksize=1)
  279. self.upsample2 = P.ResizeNearestNeighbor((feature_shape[2]//8, feature_shape[3]//8))
  280. self.backblock2 = YoloBlock(in_channels=backbone_shape[-3]+backbone_shape[-4],
  281. out_chls=backbone_shape[-4],
  282. out_channels=out_channel)
  283. self.concat = P.Concat(axis=1)
  284. def construct(self, x):
  285. # input_shape of x is (batch_size, 3, h, w)
  286. # feature_map1 is (batch_size, backbone_shape[2], h/8, w/8)
  287. # feature_map2 is (batch_size, backbone_shape[3], h/16, w/16)
  288. # feature_map3 is (batch_size, backbone_shape[4], h/32, w/32)
  289. feature_map1, feature_map2, feature_map3 = self.net(x)
  290. con1, big_object_output = self.backblock0(feature_map3)
  291. con1 = self.conv1(con1)
  292. ups1 = self.upsample1(con1)
  293. con1 = self.concat((ups1, feature_map2))
  294. con2, medium_object_output = self.backblock1(con1)
  295. con2 = self.conv2(con2)
  296. ups2 = self.upsample2(con2)
  297. con3 = self.concat((ups2, feature_map1))
  298. _, small_object_output = self.backblock2(con3)
  299. return big_object_output, medium_object_output, small_object_output
  300. class DetectionBlock(nn.Cell):
  301. """
  302. YOLOv3 detection Network. It will finally output the detection result.
  303. Args:
  304. scale (str): Character, scale.
  305. config (Class): YOLOv3 config.
  306. Returns:
  307. Tuple, tuple of output tensor,(f1,f2,f3).
  308. Examples:
  309. DetectionBlock(scale='l',stride=32).
  310. """
  311. def __init__(self, scale, config):
  312. super(DetectionBlock, self).__init__()
  313. self.config = config
  314. if scale == 's':
  315. idx = (0, 1, 2)
  316. elif scale == 'm':
  317. idx = (3, 4, 5)
  318. elif scale == 'l':
  319. idx = (6, 7, 8)
  320. else:
  321. raise KeyError("Invalid scale value for DetectionBlock")
  322. self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
  323. self.num_anchors_per_scale = 3
  324. self.num_attrib = 4 + 1 + self.config.num_classes
  325. self.ignore_threshold = 0.5
  326. self.lambda_coord = 1
  327. self.sigmoid = nn.Sigmoid()
  328. self.reshape = P.Reshape()
  329. self.tile = P.Tile()
  330. self.concat = P.Concat(axis=-1)
  331. self.input_shape = Tensor(tuple(config.img_shape[::-1]), ms.float32)
  332. def construct(self, x):
  333. num_batch = P.Shape()(x)[0]
  334. grid_size = P.Shape()(x)[2:4]
  335. # Reshape and transpose the feature to [n, 3, grid_size[0], grid_size[1], num_attrib]
  336. prediction = P.Reshape()(x, (num_batch,
  337. self.num_anchors_per_scale,
  338. self.num_attrib,
  339. grid_size[0],
  340. grid_size[1]))
  341. prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2))
  342. range_x = range(grid_size[1])
  343. range_y = range(grid_size[0])
  344. grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32)
  345. grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32)
  346. # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid
  347. grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1))
  348. grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1))
  349. # Shape is [grid_size[0], grid_size[1], 1, 2]
  350. grid = self.concat((grid_x, grid_y))
  351. box_xy = prediction[:, :, :, :, :2]
  352. box_wh = prediction[:, :, :, :, 2:4]
  353. box_confidence = prediction[:, :, :, :, 4:5]
  354. box_probs = prediction[:, :, :, :, 5:]
  355. box_xy = (self.sigmoid(box_xy) + grid) / P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32)
  356. box_wh = P.Exp()(box_wh) * self.anchors / self.input_shape
  357. box_confidence = self.sigmoid(box_confidence)
  358. box_probs = self.sigmoid(box_probs)
  359. if self.training:
  360. return grid, prediction, box_xy, box_wh
  361. return box_xy, box_wh, box_confidence, box_probs
  362. class Iou(nn.Cell):
  363. """Calculate the iou of boxes."""
  364. def __init__(self):
  365. super(Iou, self).__init__()
  366. self.min = P.Minimum()
  367. self.max = P.Maximum()
  368. def construct(self, box1, box2):
  369. box1_xy = box1[:, :, :, :, :, :2]
  370. box1_wh = box1[:, :, :, :, :, 2:4]
  371. box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0)
  372. box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0)
  373. box2_xy = box2[:, :, :, :, :, :2]
  374. box2_wh = box2[:, :, :, :, :, 2:4]
  375. box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0)
  376. box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0)
  377. intersect_mins = self.max(box1_mins, box2_mins)
  378. intersect_maxs = self.min(box1_maxs, box2_maxs)
  379. intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0))
  380. intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \
  381. P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2])
  382. box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2])
  383. box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2])
  384. iou = intersect_area / (box1_area + box2_area - intersect_area)
  385. return iou
  386. class YoloLossBlock(nn.Cell):
  387. """
  388. YOLOv3 Loss block cell. It will finally output loss of the scale.
  389. Args:
  390. scale (str): Three scale here, 's', 'm' and 'l'.
  391. config (Class): The default config of YOLOv3.
  392. Returns:
  393. Tensor, loss of the scale.
  394. Examples:
  395. YoloLossBlock('l', ConfigYOLOV3ResNet18()).
  396. """
  397. def __init__(self, scale, config):
  398. super(YoloLossBlock, self).__init__()
  399. self.config = config
  400. if scale == 's':
  401. idx = (0, 1, 2)
  402. elif scale == 'm':
  403. idx = (3, 4, 5)
  404. elif scale == 'l':
  405. idx = (6, 7, 8)
  406. else:
  407. raise KeyError("Invalid scale value for DetectionBlock")
  408. self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
  409. self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
  410. self.concat = P.Concat(axis=-1)
  411. self.iou = Iou()
  412. self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
  413. self.reduce_sum = P.ReduceSum()
  414. self.reduce_max = P.ReduceMax(keep_dims=False)
  415. self.input_shape = Tensor(tuple(config.img_shape[::-1]), ms.float32)
  416. def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box):
  417. object_mask = y_true[:, :, :, :, 4:5]
  418. class_probs = y_true[:, :, :, :, 5:]
  419. grid_shape = P.Shape()(prediction)[1:3]
  420. grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)
  421. pred_boxes = self.concat((pred_xy, pred_wh))
  422. true_xy = y_true[:, :, :, :, :2] * grid_shape - grid
  423. true_wh = y_true[:, :, :, :, 2:4]
  424. true_wh = P.Select()(P.Equal()(true_wh, 0.0),
  425. P.Fill()(P.DType()(true_wh), P.Shape()(true_wh), 1.0),
  426. true_wh)
  427. true_wh = P.Log()(true_wh / self.anchors * self.input_shape)
  428. box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
  429. gt_shape = P.Shape()(gt_box)
  430. gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))
  431. iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box) # [batch, grid[0], grid[1], num_anchor, num_gt]
  432. best_iou = self.reduce_max(iou, -1) # [batch, grid[0], grid[1], num_anchor]
  433. ignore_mask = best_iou < self.ignore_threshold
  434. ignore_mask = P.Cast()(ignore_mask, ms.float32)
  435. ignore_mask = P.ExpandDims()(ignore_mask, -1)
  436. ignore_mask = F.stop_gradient(ignore_mask)
  437. xy_loss = object_mask * box_loss_scale * self.cross_entropy(prediction[:, :, :, :, :2], true_xy)
  438. wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - prediction[:, :, :, :, 2:4])
  439. confidence_loss = self.cross_entropy(prediction[:, :, :, :, 4:5], object_mask)
  440. confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
  441. class_loss = object_mask * self.cross_entropy(prediction[:, :, :, :, 5:], class_probs)
  442. # Get smooth loss
  443. xy_loss = self.reduce_sum(xy_loss, ())
  444. wh_loss = self.reduce_sum(wh_loss, ())
  445. confidence_loss = self.reduce_sum(confidence_loss, ())
  446. class_loss = self.reduce_sum(class_loss, ())
  447. loss = xy_loss + wh_loss + confidence_loss + class_loss
  448. return loss / P.Shape()(prediction)[0]
  449. class yolov3_resnet18(nn.Cell):
  450. """
  451. ResNet based YOLOv3 network.
  452. Args:
  453. config (Class): YOLOv3 config.
  454. Returns:
  455. Cell, cell instance of ResNet based YOLOv3 neural network.
  456. Examples:
  457. yolov3_resnet18(80, [1,3,416,416]).
  458. """
  459. def __init__(self, config):
  460. super(yolov3_resnet18, self).__init__()
  461. self.config = config
  462. # YOLOv3 network
  463. self.feature_map = YOLOv3(feature_shape=self.config.feature_shape,
  464. backbone=ResNet(BasicBlock,
  465. self.config.backbone_layers,
  466. self.config.backbone_input_shape,
  467. self.config.backbone_shape,
  468. self.config.backbone_stride,
  469. num_classes=None),
  470. backbone_shape=self.config.backbone_shape,
  471. out_channel=self.config.out_channel)
  472. # prediction on the default anchor boxes
  473. self.detect_1 = DetectionBlock('l', self.config)
  474. self.detect_2 = DetectionBlock('m', self.config)
  475. self.detect_3 = DetectionBlock('s', self.config)
  476. def construct(self, x):
  477. big_object_output, medium_object_output, small_object_output = self.feature_map(x)
  478. output_big = self.detect_1(big_object_output)
  479. output_me = self.detect_2(medium_object_output)
  480. output_small = self.detect_3(small_object_output)
  481. return output_big, output_me, output_small
  482. class YoloWithLossCell(nn.Cell):
  483. """"
  484. Provide YOLOv3 training loss through network.
  485. Args:
  486. network (Cell): The training network.
  487. config (Class): YOLOv3 config.
  488. Returns:
  489. Tensor, the loss of the network.
  490. """
  491. def __init__(self, network, config):
  492. super(YoloWithLossCell, self).__init__()
  493. self.yolo_network = network
  494. self.config = config
  495. self.loss_big = YoloLossBlock('l', self.config)
  496. self.loss_me = YoloLossBlock('m', self.config)
  497. self.loss_small = YoloLossBlock('s', self.config)
  498. def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2):
  499. yolo_out = self.yolo_network(x)
  500. loss_l = self.loss_big(yolo_out[0][0], yolo_out[0][1], yolo_out[0][2], yolo_out[0][3], y_true_0, gt_0)
  501. loss_m = self.loss_me(yolo_out[1][0], yolo_out[1][1], yolo_out[1][2], yolo_out[1][3], y_true_1, gt_1)
  502. loss_s = self.loss_small(yolo_out[2][0], yolo_out[2][1], yolo_out[2][2], yolo_out[2][3], y_true_2, gt_2)
  503. return loss_l + loss_m + loss_s
  504. class TrainingWrapper(nn.Cell):
  505. """
  506. Encapsulation class of YOLOv3 network training.
  507. Append an optimizer to the training network after that the construct
  508. function can be called to create the backward graph.
  509. Args:
  510. network (Cell): The training network. Note that loss function should have been added.
  511. optimizer (Optimizer): Optimizer for updating the weights.
  512. sens (Number): The adjust parameter. Default: 1.0.
  513. """
  514. def __init__(self, network, optimizer, sens=1.0):
  515. super(TrainingWrapper, self).__init__(auto_prefix=False)
  516. self.network = network
  517. self.weights = ms.ParameterTuple(network.trainable_params())
  518. self.optimizer = optimizer
  519. self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
  520. self.sens = sens
  521. self.reducer_flag = False
  522. self.grad_reducer = None
  523. self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
  524. if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
  525. self.reducer_flag = True
  526. if self.reducer_flag:
  527. mean = context.get_auto_parallel_context("mirror_mean")
  528. if auto_parallel_context().get_device_num_is_set():
  529. degree = context.get_auto_parallel_context("device_num")
  530. else:
  531. degree = get_group_size()
  532. self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
  533. def construct(self, *args):
  534. weights = self.weights
  535. loss = self.network(*args)
  536. sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
  537. grads = self.grad(self.network, weights)(*args, sens)
  538. if self.reducer_flag:
  539. # apply grad reducer on grads
  540. grads = self.grad_reducer(grads)
  541. return F.depend(loss, self.optimizer(grads))
  542. class YoloBoxScores(nn.Cell):
  543. """
  544. Calculate the boxes of the original picture size and the score of each box.
  545. Args:
  546. config (Class): YOLOv3 config.
  547. Returns:
  548. Tensor, the boxes of the original picture size.
  549. Tensor, the score of each box.
  550. """
  551. def __init__(self, config):
  552. super(YoloBoxScores, self).__init__()
  553. self.input_shape = Tensor(np.array(config.img_shape), ms.float32)
  554. self.num_classes = config.num_classes
  555. def construct(self, box_xy, box_wh, box_confidence, box_probs, image_shape):
  556. batch_size = F.shape(box_xy)[0]
  557. x = box_xy[:, :, :, :, 0:1]
  558. y = box_xy[:, :, :, :, 1:2]
  559. box_yx = P.Concat(-1)((y, x))
  560. w = box_wh[:, :, :, :, 0:1]
  561. h = box_wh[:, :, :, :, 1:2]
  562. box_hw = P.Concat(-1)((h, w))
  563. new_shape = P.Round()(image_shape * P.ReduceMin()(self.input_shape / image_shape))
  564. offset = (self.input_shape - new_shape) / 2.0 / self.input_shape
  565. scale = self.input_shape / new_shape
  566. box_yx = (box_yx - offset) * scale
  567. box_hw = box_hw * scale
  568. box_min = box_yx - box_hw / 2.0
  569. box_max = box_yx + box_hw / 2.0
  570. boxes = P.Concat(-1)((box_min[:, :, :, :, 0:1],
  571. box_min[:, :, :, :, 1:2],
  572. box_max[:, :, :, :, 0:1],
  573. box_max[:, :, :, :, 1:2]))
  574. image_scale = P.Tile()(image_shape, (1, 2))
  575. boxes = boxes * image_scale
  576. boxes = F.reshape(boxes, (batch_size, -1, 4))
  577. boxes_scores = box_confidence * box_probs
  578. boxes_scores = F.reshape(boxes_scores, (batch_size, -1, self.num_classes))
  579. return boxes, boxes_scores
  580. class YoloWithEval(nn.Cell):
  581. """
  582. Encapsulation class of YOLOv3 evaluation.
  583. Args:
  584. network (Cell): The training network. Note that loss function and optimizer must not be added.
  585. config (Class): YOLOv3 config.
  586. Returns:
  587. Tensor, the boxes of the original picture size.
  588. Tensor, the score of each box.
  589. Tensor, the original picture size.
  590. """
  591. def __init__(self, network, config):
  592. super(YoloWithEval, self).__init__()
  593. self.yolo_network = network
  594. self.box_score_0 = YoloBoxScores(config)
  595. self.box_score_1 = YoloBoxScores(config)
  596. self.box_score_2 = YoloBoxScores(config)
  597. def construct(self, x, image_shape):
  598. yolo_output = self.yolo_network(x)
  599. boxes_0, boxes_scores_0 = self.box_score_0(*yolo_output[0], image_shape)
  600. boxes_1, boxes_scores_1 = self.box_score_1(*yolo_output[1], image_shape)
  601. boxes_2, boxes_scores_2 = self.box_score_2(*yolo_output[2], image_shape)
  602. boxes = P.Concat(1)((boxes_0, boxes_1, boxes_2))
  603. boxes_scores = P.Concat(1)((boxes_scores_0, boxes_scores_1, boxes_scores_2))
  604. return boxes, boxes_scores, image_shape